forked from mindspore-Ecosystem/mindspore
!48484 [MS] c api add control flow api
Merge pull request !48484 from XianglongZeng/c_api_pr
This commit is contained in:
commit
6e010ea825
|
@ -26,9 +26,17 @@ typedef void *TensorHandle;
|
|||
typedef void *NodeHandle;
|
||||
typedef void *AttrHandle;
|
||||
typedef void *GraphHandle;
|
||||
typedef void *FuncGraphManagerHandle;
|
||||
typedef void *FuncGraphMgrHandle;
|
||||
typedef void *ResMgrHandle;
|
||||
|
||||
typedef const void *ConstHandle;
|
||||
typedef const void *ConstTensorHandle;
|
||||
typedef const void *ConstNodeHandle;
|
||||
typedef const void *ConstAttrHandle;
|
||||
typedef const void *ConstGraphHandle;
|
||||
typedef const void *ConstFuncGraphMgrHandle;
|
||||
typedef const void *ConstResMgrHandle;
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -35,7 +35,7 @@ extern "C" {
|
|||
/// \param[in] input_node The input node which contains the Abstract.
|
||||
///
|
||||
/// \return Error code indicates whether the function executed successfully.
|
||||
MIND_C_API STATUS MSAssignAbstract(ResMgrHandle res_mgr, NodeHandle cur_node, NodeHandle input_node);
|
||||
MIND_C_API STATUS MSAssignAbstract(ResMgrHandle res_mgr, NodeHandle cur_node, ConstNodeHandle input_node);
|
||||
|
||||
/// \brief Set Abstract to the node with type and shape.
|
||||
///
|
||||
|
|
|
@ -134,7 +134,8 @@ MIND_C_API STATUS MSOpSetAttrString(ResMgrHandle res_mgr, NodeHandle op, const c
|
|||
/// \param[in] error Error code indicates whether the function executed successfully.
|
||||
///
|
||||
/// \return Attribute value
|
||||
MIND_C_API int64_t MSOpGetScalarAttrInt64(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, STATUS *error);
|
||||
MIND_C_API int64_t MSOpGetScalarAttrInt64(ResMgrHandle res_mgr, ConstNodeHandle op, const char *attr_name,
|
||||
STATUS *error);
|
||||
|
||||
/// \brief Get the attribute of the target node with the given attribute name.
|
||||
///
|
||||
|
@ -145,8 +146,8 @@ MIND_C_API int64_t MSOpGetScalarAttrInt64(ResMgrHandle res_mgr, NodeHandle op, c
|
|||
/// \param[in] value_num Size of the given array.
|
||||
///
|
||||
/// \return Error code indicates whether the function executed successfully.
|
||||
MIND_C_API STATUS MSOpGetAttrArrayInt64(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, int64_t values[],
|
||||
size_t value_num);
|
||||
MIND_C_API STATUS MSOpGetAttrArrayInt64(ResMgrHandle res_mgr, ConstNodeHandle op, const char *attr_name,
|
||||
int64_t values[], size_t value_num);
|
||||
|
||||
/// \brief Create new Int64 attribute scalar value.
|
||||
///
|
||||
|
|
|
@ -43,7 +43,7 @@ MIND_C_API GraphHandle MSFuncGraphCreate(ResMgrHandle res_mgr);
|
|||
/// \param[in] i Index of the input node.
|
||||
///
|
||||
/// \return The created function graph.
|
||||
MIND_C_API NodeHandle MSFuncGraphGetInput(ResMgrHandle res_mgr, const GraphHandle graph, size_t i);
|
||||
MIND_C_API NodeHandle MSFuncGraphGetInput(ResMgrHandle res_mgr, ConstGraphHandle graph, size_t i);
|
||||
|
||||
/// \brief Get the inputs number of the function graph.
|
||||
///
|
||||
|
@ -52,7 +52,7 @@ MIND_C_API NodeHandle MSFuncGraphGetInput(ResMgrHandle res_mgr, const GraphHandl
|
|||
/// \param[in] error Records error code that indicate whether the functions executed successfully.
|
||||
///
|
||||
/// \return the inputs number of the function graph.
|
||||
MIND_C_API size_t MSFuncGraphGetInputNum(ResMgrHandle res_mgr, const GraphHandle graph, STATUS *error);
|
||||
MIND_C_API size_t MSFuncGraphGetInputNum(ResMgrHandle res_mgr, ConstGraphHandle graph, STATUS *error);
|
||||
|
||||
/// \brief Get all inputs of the function graph.
|
||||
///
|
||||
|
@ -62,7 +62,7 @@ MIND_C_API size_t MSFuncGraphGetInputNum(ResMgrHandle res_mgr, const GraphHandle
|
|||
/// \param[in] input_num The length of the array.
|
||||
///
|
||||
/// \return Error code that indicate whether the functions executed successfully.
|
||||
MIND_C_API STATUS MSFuncGraphGetInputs(ResMgrHandle res_mgr, const GraphHandle graph, NodeHandle inputs[],
|
||||
MIND_C_API STATUS MSFuncGraphGetInputs(ResMgrHandle res_mgr, ConstGraphHandle graph, NodeHandle inputs[],
|
||||
size_t input_num);
|
||||
|
||||
/// \brief Set the output node.
|
||||
|
@ -73,7 +73,7 @@ MIND_C_API STATUS MSFuncGraphGetInputs(ResMgrHandle res_mgr, const GraphHandle g
|
|||
/// \param[in] force_new_ret If true, a new return node is always created.
|
||||
///
|
||||
/// \return Error code that indicate whether the functions executed successfully.
|
||||
MIND_C_API STATUS MSFuncGraphSetOutput(ResMgrHandle res_mgr, GraphHandle graph, const NodeHandle op_node,
|
||||
MIND_C_API STATUS MSFuncGraphSetOutput(ResMgrHandle res_mgr, GraphHandle graph, ConstNodeHandle op_node,
|
||||
bool force_new_ret);
|
||||
|
||||
/// \brief Set the output node.
|
||||
|
@ -84,7 +84,7 @@ MIND_C_API STATUS MSFuncGraphSetOutput(ResMgrHandle res_mgr, GraphHandle graph,
|
|||
/// \param[in] force_new_ret If true, a new return node is always created.
|
||||
///
|
||||
/// \return Error code that indicate whether the functions executed successfully.
|
||||
MIND_C_API STATUS MSFuncGraphSetOutputs(ResMgrHandle res_mgr, GraphHandle graph, const Handle outputs[],
|
||||
MIND_C_API STATUS MSFuncGraphSetOutputs(ResMgrHandle res_mgr, GraphHandle graph, Handle const outputs[],
|
||||
size_t output_num, bool force_new_ret);
|
||||
|
||||
/// \brief Get the output node according to the index.
|
||||
|
@ -94,7 +94,7 @@ MIND_C_API STATUS MSFuncGraphSetOutputs(ResMgrHandle res_mgr, GraphHandle graph,
|
|||
/// \param[in] i The index to get the output. If there is only one output for graph, the i should be 0;
|
||||
///
|
||||
/// \return The output node, nullptr if output not set.
|
||||
MIND_C_API NodeHandle MSFuncGraphGetOutput(ResMgrHandle res_mgr, const GraphHandle graph, size_t i);
|
||||
MIND_C_API NodeHandle MSFuncGraphGetOutput(ResMgrHandle res_mgr, ConstGraphHandle graph, size_t i);
|
||||
|
||||
/// \brief Get the outputs number of the function graph.
|
||||
///
|
||||
|
@ -103,7 +103,7 @@ MIND_C_API NodeHandle MSFuncGraphGetOutput(ResMgrHandle res_mgr, const GraphHand
|
|||
/// \param[in] error Records error code that indicate whether the functions executed successfully.
|
||||
///
|
||||
/// \return the outputs number of the function graph.
|
||||
MIND_C_API size_t MSFuncGraphGetOutputNum(ResMgrHandle res_mgr, GraphHandle graph, STATUS *error);
|
||||
MIND_C_API size_t MSFuncGraphGetOutputNum(ResMgrHandle res_mgr, ConstGraphHandle graph, STATUS *error);
|
||||
|
||||
/// \brief Get all outputs of the function graph.
|
||||
///
|
||||
|
@ -113,7 +113,7 @@ MIND_C_API size_t MSFuncGraphGetOutputNum(ResMgrHandle res_mgr, GraphHandle grap
|
|||
/// \param[in] output_num The length of the array.
|
||||
///
|
||||
/// \return Error code that indicate whether the functions executed successfully.
|
||||
MIND_C_API STATUS MSFuncGraphGetOutputs(ResMgrHandle res_mgr, const GraphHandle graph, NodeHandle outputs[],
|
||||
MIND_C_API STATUS MSFuncGraphGetOutputs(ResMgrHandle res_mgr, ConstGraphHandle graph, NodeHandle outputs[],
|
||||
size_t output_num);
|
||||
|
||||
/// \brief Replace a node in a function graph.
|
||||
|
@ -124,8 +124,8 @@ MIND_C_API STATUS MSFuncGraphGetOutputs(ResMgrHandle res_mgr, const GraphHandle
|
|||
/// \param[in] new_node The replace node.
|
||||
///
|
||||
/// \return Error code that indicate whether the functions executed successfully.
|
||||
MIND_C_API STATUS MSFuncGraphReplace(ResMgrHandle res_mgr, GraphHandle graph, const NodeHandle old_node,
|
||||
const NodeHandle new_node);
|
||||
MIND_C_API STATUS MSFuncGraphReplace(ResMgrHandle res_mgr, GraphHandle graph, ConstNodeHandle old_node,
|
||||
ConstNodeHandle new_node);
|
||||
|
||||
/// \brief Compile the function graph.
|
||||
///
|
||||
|
@ -145,7 +145,7 @@ MIND_C_API STATUS MSFuncGraphCompile(ResMgrHandle res_mgr, GraphHandle graph);
|
|||
/// \param[in] outputs_num The output size.
|
||||
///
|
||||
/// \return Error code that indicate whether the function graph executed successfully.
|
||||
MIND_C_API STATUS MSFuncGraphRun(ResMgrHandle res_mgr, GraphHandle graph, const TensorHandle inputs[], size_t input_num,
|
||||
MIND_C_API STATUS MSFuncGraphRun(ResMgrHandle res_mgr, GraphHandle graph, TensorHandle const inputs[], size_t input_num,
|
||||
TensorHandle outputs[], size_t outputs_num);
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -42,7 +42,7 @@ extern "C" {
|
|||
/// \param[in] attr_num The number of attributes.
|
||||
///
|
||||
/// \return The created Operator node handle
|
||||
MIND_C_API NodeHandle MSNewOp(ResMgrHandle res_mgr, GraphHandle graph, const char *op_type, const Handle inputs[],
|
||||
MIND_C_API NodeHandle MSNewOp(ResMgrHandle res_mgr, GraphHandle graph, const char *op_type, Handle const inputs[],
|
||||
size_t input_num, char **attr_names, AttrHandle attrs[], size_t attr_num);
|
||||
|
||||
/// \brief Pack nodes into a Tuple node.
|
||||
|
@ -53,9 +53,9 @@ MIND_C_API NodeHandle MSNewOp(ResMgrHandle res_mgr, GraphHandle graph, const cha
|
|||
/// \param[in] node_num The number of nodes in the array.
|
||||
///
|
||||
/// \return The created Tuple node handle.
|
||||
MIND_C_API NodeHandle MSPackNodesTuple(ResMgrHandle res_mgr, GraphHandle graph, const Handle nodes[], size_t node_num);
|
||||
MIND_C_API NodeHandle MSPackNodesTuple(ResMgrHandle res_mgr, GraphHandle graph, Handle const nodes[], size_t node_num);
|
||||
|
||||
/// \brief Get specified output branch from as multi-output Operator.
|
||||
/// \brief Get specified output branch from a multi-output Operator.
|
||||
///
|
||||
/// \param[in] res_mgr Resource manager that saves allocated instance resources.
|
||||
/// \param[in] graph The given function graph pointer handle.
|
||||
|
@ -63,7 +63,31 @@ MIND_C_API NodeHandle MSPackNodesTuple(ResMgrHandle res_mgr, GraphHandle graph,
|
|||
/// \param[in] i The index of the output branch.
|
||||
///
|
||||
/// \return The obtained output node.
|
||||
MIND_C_API NodeHandle MSOpGetSpecOutput(ResMgrHandle res_mgr, GraphHandle graph, const NodeHandle op, size_t i);
|
||||
MIND_C_API NodeHandle MSOpGetSpecOutput(ResMgrHandle res_mgr, GraphHandle graph, ConstNodeHandle op, size_t i);
|
||||
|
||||
/// \brief Create a Switch operator for control-flow scene.
|
||||
///
|
||||
/// \param[in] res_mgr Resource manager that saves allocated instance resources.
|
||||
/// \param[in] graph The given function graph pointer handle.
|
||||
/// \param[in] cond The condition of Switch which can be an Operator or a Subgraph.
|
||||
/// \param[in] true_br The true branch of Switch which must be a Subgraph.
|
||||
/// \param[in] false_br The false branch of Switch which must be a Subgraph.
|
||||
///
|
||||
/// \return The created Switch operator node.
|
||||
MIND_C_API NodeHandle MSNewSwitch(ResMgrHandle res_mgr, GraphHandle graph, Handle cond, ConstGraphHandle true_br,
|
||||
ConstGraphHandle false_br);
|
||||
|
||||
/// \brief Create a While operator for control-flow scene.
|
||||
///
|
||||
/// \param[in] res_mgr Resource manager that saves allocated instance resources.
|
||||
/// \param[in] graph The given function graph pointer handle.
|
||||
/// \param[in] cond The condition of While which can be an Operator or a Subgraph.
|
||||
/// \param[in] body_graph The loop body of While which must be a Subgraph.
|
||||
/// \param[in] after_graph The graph after stepping out the While which must be a Subgraph.
|
||||
///
|
||||
/// \return The created While operator node.
|
||||
MIND_C_API NodeHandle MSNewWhile(ResMgrHandle res_mgr, GraphHandle graph, Handle cond, GraphHandle body_graph,
|
||||
GraphHandle after_graph);
|
||||
|
||||
/// \brief Get specified input node of Operator.
|
||||
///
|
||||
|
@ -72,7 +96,7 @@ MIND_C_API NodeHandle MSOpGetSpecOutput(ResMgrHandle res_mgr, GraphHandle graph,
|
|||
/// \param[in] i The index of the input.
|
||||
///
|
||||
/// \return The obtained input node handle.
|
||||
MIND_C_API NodeHandle MSOpGetInput(ResMgrHandle res_mgr, const NodeHandle op, size_t i);
|
||||
MIND_C_API NodeHandle MSOpGetInput(ResMgrHandle res_mgr, ConstNodeHandle op, size_t i);
|
||||
|
||||
/// \brief Get the input nodes number of Operator.
|
||||
///
|
||||
|
@ -81,7 +105,7 @@ MIND_C_API NodeHandle MSOpGetInput(ResMgrHandle res_mgr, const NodeHandle op, si
|
|||
/// \param[in] error Records error code that indicate whether the functions executed successfully.
|
||||
///
|
||||
/// \return The input nodes number.
|
||||
MIND_C_API size_t MSOpGetInputsNum(ResMgrHandle res_mgr, const NodeHandle op, STATUS *error);
|
||||
MIND_C_API size_t MSOpGetInputsNum(ResMgrHandle res_mgr, ConstNodeHandle op, STATUS *error);
|
||||
|
||||
/// \brief Get all input nodes of the Operator.
|
||||
///
|
||||
|
@ -91,7 +115,7 @@ MIND_C_API size_t MSOpGetInputsNum(ResMgrHandle res_mgr, const NodeHandle op, ST
|
|||
/// \param[in] input_num The size of the input array.
|
||||
///
|
||||
/// \return Error code that indicate whether the functions executed successfully.
|
||||
MIND_C_API STATUS MSOpGetInputs(ResMgrHandle res_mgr, const NodeHandle op, NodeHandle inputs[], size_t input_num);
|
||||
MIND_C_API STATUS MSOpGetInputs(ResMgrHandle res_mgr, ConstNodeHandle op, NodeHandle inputs[], size_t input_num);
|
||||
|
||||
/// \brief Create a subgraph node.
|
||||
///
|
||||
|
@ -102,8 +126,8 @@ MIND_C_API STATUS MSOpGetInputs(ResMgrHandle res_mgr, const NodeHandle op, NodeH
|
|||
/// \param[in] input_num The number of the input array.
|
||||
///
|
||||
/// \return The created subgraph node handle.
|
||||
MIND_C_API NodeHandle MSNewSubGraphNode(ResMgrHandle res_mgr, GraphHandle graph, GraphHandle sub_graph,
|
||||
const Handle inputs[], size_t input_num);
|
||||
MIND_C_API NodeHandle MSNewFuncCallNode(ResMgrHandle res_mgr, GraphHandle graph, ConstGraphHandle sub_graph,
|
||||
Handle const inputs[], size_t input_num);
|
||||
|
||||
/// \brief Create a Placeholder node, which is usually the input of graph without data.
|
||||
///
|
||||
|
@ -138,7 +162,7 @@ MIND_C_API NodeHandle MSNewTensorVariable(ResMgrHandle res_mgr, GraphHandle grap
|
|||
/// \param[in] tensor The given Tensor instance.
|
||||
///
|
||||
/// \return The created Variable node handle.
|
||||
MIND_C_API NodeHandle MSNewTensorVariableFromTensor(ResMgrHandle res_mgr, GraphHandle graph, TensorHandle tensor);
|
||||
MIND_C_API NodeHandle MSNewTensorVariableFromTensor(ResMgrHandle res_mgr, GraphHandle graph, ConstTensorHandle tensor);
|
||||
|
||||
/// \brief Get data size of a Tensor Variable.
|
||||
///
|
||||
|
@ -147,7 +171,7 @@ MIND_C_API NodeHandle MSNewTensorVariableFromTensor(ResMgrHandle res_mgr, GraphH
|
|||
/// \param[in] error Records error code that indicate whether the functions executed successfully.
|
||||
///
|
||||
/// \return The data byte size.
|
||||
MIND_C_API size_t MSTensorVariableGetDataSize(ResMgrHandle res_mgr, NodeHandle node, STATUS *error);
|
||||
MIND_C_API size_t MSTensorVariableGetDataSize(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error);
|
||||
|
||||
/// \brief Get data from a Tensor Variable.
|
||||
///
|
||||
|
@ -155,7 +179,7 @@ MIND_C_API size_t MSTensorVariableGetDataSize(ResMgrHandle res_mgr, NodeHandle n
|
|||
/// \param[in] node The tensor variable node
|
||||
///
|
||||
/// \return The data.
|
||||
MIND_C_API void *MSTensorVariableGetData(ResMgrHandle res_mgr, NodeHandle node);
|
||||
MIND_C_API void *MSTensorVariableGetData(ResMgrHandle res_mgr, ConstNodeHandle node);
|
||||
|
||||
/// \brief Create a Constant node of tensor, which contains constant tensor data.
|
||||
///
|
||||
|
@ -185,7 +209,7 @@ MIND_C_API NodeHandle MSNewTensorConstantFromTensor(ResMgrHandle res_mgr, Tensor
|
|||
/// \param[in] error Records error code that indicate whether the functions executed successfully.
|
||||
///
|
||||
/// \return The data byte size.
|
||||
MIND_C_API size_t MSTensorConstantGetDataSize(ResMgrHandle res_mgr, NodeHandle node, STATUS *error);
|
||||
MIND_C_API size_t MSTensorConstantGetDataSize(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error);
|
||||
|
||||
/// \brief Get data from a Tensor Constant.
|
||||
///
|
||||
|
@ -193,7 +217,7 @@ MIND_C_API size_t MSTensorConstantGetDataSize(ResMgrHandle res_mgr, NodeHandle n
|
|||
/// \param[in] node The tensor constant node
|
||||
///
|
||||
/// \return The data.
|
||||
MIND_C_API void *MSTensorConstantGetData(ResMgrHandle res_mgr, NodeHandle node);
|
||||
MIND_C_API void *MSTensorConstantGetData(ResMgrHandle res_mgr, ConstNodeHandle node);
|
||||
|
||||
/// \brief Create Constant node of a float scalar.
|
||||
///
|
||||
|
@ -259,7 +283,7 @@ MIND_C_API NodeHandle MSNewTypeConstant(ResMgrHandle res_mgr, TypeId type);
|
|||
/// \param[in] error Records error code that indicate whether the functions executed successfully.
|
||||
///
|
||||
/// \return The obtained int32 value.
|
||||
MIND_C_API int MSScalarConstantGetValueInt32(ResMgrHandle res_mgr, const NodeHandle node, STATUS *error);
|
||||
MIND_C_API int MSScalarConstantGetValueInt32(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error);
|
||||
|
||||
/// \brief Get value from the float32 scalar Constant node.
|
||||
///
|
||||
|
@ -268,7 +292,7 @@ MIND_C_API int MSScalarConstantGetValueInt32(ResMgrHandle res_mgr, const NodeHan
|
|||
/// \param[in] error Records error code that indicate whether the functions executed successfully.
|
||||
///
|
||||
/// \return The obtained float32 value.
|
||||
MIND_C_API float MSScalarConstantGetValueFloat32(ResMgrHandle res_mgr, const NodeHandle node, STATUS *error);
|
||||
MIND_C_API float MSScalarConstantGetValueFloat32(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error);
|
||||
|
||||
/// \brief Get value from the bool scalar Constant node.
|
||||
///
|
||||
|
@ -277,7 +301,7 @@ MIND_C_API float MSScalarConstantGetValueFloat32(ResMgrHandle res_mgr, const Nod
|
|||
/// \param[in] error Records error code that indicate whether the functions executed successfully.
|
||||
///
|
||||
/// \return The obtained bool value.
|
||||
MIND_C_API bool MSScalarConstantGetValueBool(ResMgrHandle res_mgr, const NodeHandle node, STATUS *error);
|
||||
MIND_C_API bool MSScalarConstantGetValueBool(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error);
|
||||
|
||||
/// \brief Get value from the int64 scalar Constant node.
|
||||
///
|
||||
|
@ -286,7 +310,7 @@ MIND_C_API bool MSScalarConstantGetValueBool(ResMgrHandle res_mgr, const NodeHan
|
|||
/// \param[in] error Records error code that indicate whether the functions executed successfully.
|
||||
///
|
||||
/// \return The obtained int64 value.
|
||||
MIND_C_API int64_t MSScalarConstantGetValueInt64(ResMgrHandle res_mgr, const NodeHandle node, STATUS *error);
|
||||
MIND_C_API int64_t MSScalarConstantGetValueInt64(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error);
|
||||
|
||||
/// \brief Get value from the string Constant node.
|
||||
///
|
||||
|
@ -296,7 +320,7 @@ MIND_C_API int64_t MSScalarConstantGetValueInt64(ResMgrHandle res_mgr, const Nod
|
|||
/// \param[in] str_len The size of the char array.
|
||||
///
|
||||
/// \return The error code that indicate whether the functions executed successfully.
|
||||
MIND_C_API STATUS MSStringConstantGetValue(ResMgrHandle res_mgr, const NodeHandle node, char str_buf[], size_t str_len);
|
||||
MIND_C_API STATUS MSStringConstantGetValue(ResMgrHandle res_mgr, ConstNodeHandle node, char str_buf[], size_t str_len);
|
||||
|
||||
/// \brief Get value from the tuple Constant node.
|
||||
///
|
||||
|
@ -305,7 +329,7 @@ MIND_C_API STATUS MSStringConstantGetValue(ResMgrHandle res_mgr, const NodeHandl
|
|||
/// \param[in] error Records error code that indicate whether the functions executed successfully.
|
||||
///
|
||||
/// \return The size of the Tuple.
|
||||
MIND_C_API size_t MSTupleConstantGetSize(ResMgrHandle res_mgr, const NodeHandle node, STATUS *error);
|
||||
MIND_C_API size_t MSTupleConstantGetSize(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error);
|
||||
|
||||
/// \brief Get value from the Tuple Constant node.
|
||||
///
|
||||
|
@ -315,7 +339,7 @@ MIND_C_API size_t MSTupleConstantGetSize(ResMgrHandle res_mgr, const NodeHandle
|
|||
/// \param[in] size The size of the value vector.
|
||||
///
|
||||
/// \return The error code that indicate whether the functions executed successfully.
|
||||
MIND_C_API STATUS MSTupleConstantGetValueInt64(ResMgrHandle res_mgr, const NodeHandle node, int64_t vec[], size_t size);
|
||||
MIND_C_API STATUS MSTupleConstantGetValueInt64(ResMgrHandle res_mgr, ConstNodeHandle node, int64_t vec[], size_t size);
|
||||
|
||||
/// \brief Get value from the Type Constant node.
|
||||
///
|
||||
|
@ -324,7 +348,7 @@ MIND_C_API STATUS MSTupleConstantGetValueInt64(ResMgrHandle res_mgr, const NodeH
|
|||
/// \param[in] error Records error code that indicate whether the functions executed successfully.
|
||||
///
|
||||
/// \return The obtained type value.
|
||||
MIND_C_API TypeId MSTypeConstantGetValue(ResMgrHandle res_mgr, const NodeHandle node, STATUS *error);
|
||||
MIND_C_API TypeId MSTypeConstantGetValue(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error);
|
||||
|
||||
/// \brief Set Operator node name.
|
||||
///
|
||||
|
@ -333,7 +357,7 @@ MIND_C_API TypeId MSTypeConstantGetValue(ResMgrHandle res_mgr, const NodeHandle
|
|||
/// \param[in] name The op node name to be set.
|
||||
///
|
||||
/// \return The error code that indicate whether the functions executed successfully.
|
||||
MIND_C_API STATUS MSOpSetName(ResMgrHandle res_mgr, const NodeHandle node, const char *name);
|
||||
MIND_C_API STATUS MSOpSetName(ResMgrHandle res_mgr, NodeHandle node, const char *name);
|
||||
|
||||
/// \brief Get the name of node.
|
||||
///
|
||||
|
@ -343,7 +367,7 @@ MIND_C_API STATUS MSOpSetName(ResMgrHandle res_mgr, const NodeHandle node, const
|
|||
/// \param[in] str_len The size of the char array.
|
||||
///
|
||||
/// \return The error code that indicate whether the functions executed successfully.
|
||||
MIND_C_API STATUS MSNodeGetName(ResMgrHandle res_mgr, const NodeHandle node, char str_buf[], size_t str_len);
|
||||
MIND_C_API STATUS MSNodeGetName(ResMgrHandle res_mgr, ConstNodeHandle node, char str_buf[], size_t str_len);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -73,7 +73,7 @@ MIND_C_API TensorHandle MSNewTensorWithSrcType(ResMgrHandle res_mgr, void *data,
|
|||
/// \param[in] tensor The pointer of the tensor instance.
|
||||
///
|
||||
/// \return The pointer to the tensor data
|
||||
MIND_C_API void *MSTensorGetData(ResMgrHandle res_mgr, const TensorHandle tensor);
|
||||
MIND_C_API void *MSTensorGetData(ResMgrHandle res_mgr, ConstTensorHandle tensor);
|
||||
|
||||
/// \brief Set tensor data type.
|
||||
///
|
||||
|
@ -82,7 +82,7 @@ MIND_C_API void *MSTensorGetData(ResMgrHandle res_mgr, const TensorHandle tensor
|
|||
/// \param[in] type The data type to be set.
|
||||
///
|
||||
/// \return Error code that indicate whether the functions executed successfully.
|
||||
MIND_C_API STATUS MSTensorSetDataType(ResMgrHandle res_mgr, const TensorHandle tensor, TypeId type);
|
||||
MIND_C_API STATUS MSTensorSetDataType(ResMgrHandle res_mgr, TensorHandle tensor, TypeId type);
|
||||
|
||||
/// \brief Get tensor data type.
|
||||
///
|
||||
|
@ -90,7 +90,7 @@ MIND_C_API STATUS MSTensorSetDataType(ResMgrHandle res_mgr, const TensorHandle t
|
|||
/// \param[in] tensor The pointer of the tensor instance.
|
||||
///
|
||||
/// \return The data type of tensor.
|
||||
MIND_C_API TypeId MSTensorGetDataType(ResMgrHandle res_mgr, const TensorHandle tensor, STATUS *error);
|
||||
MIND_C_API TypeId MSTensorGetDataType(ResMgrHandle res_mgr, ConstTensorHandle tensor, STATUS *error);
|
||||
|
||||
/// \brief Get the byte size of tensor data.
|
||||
///
|
||||
|
@ -99,7 +99,7 @@ MIND_C_API TypeId MSTensorGetDataType(ResMgrHandle res_mgr, const TensorHandle t
|
|||
/// \param[in] error Records error code that indicate whether the functions executed successfully.
|
||||
///
|
||||
/// \return The byte size of tensor data.
|
||||
MIND_C_API size_t MSTensorGetDataSize(ResMgrHandle res_mgr, const TensorHandle tensor, STATUS *error);
|
||||
MIND_C_API size_t MSTensorGetDataSize(ResMgrHandle res_mgr, ConstTensorHandle tensor, STATUS *error);
|
||||
|
||||
/// \brief Get the element number of tensor array.
|
||||
///
|
||||
|
@ -108,7 +108,7 @@ MIND_C_API size_t MSTensorGetDataSize(ResMgrHandle res_mgr, const TensorHandle t
|
|||
/// \param[in] error Records error code that indicate whether the functions executed successfully.
|
||||
///
|
||||
/// \return The element number of tensor array.
|
||||
MIND_C_API size_t MSTensorGetElementNum(ResMgrHandle res_mgr, const TensorHandle tensor, STATUS *error);
|
||||
MIND_C_API size_t MSTensorGetElementNum(ResMgrHandle res_mgr, ConstTensorHandle tensor, STATUS *error);
|
||||
|
||||
/// \brief Get the dimension of tensor.
|
||||
///
|
||||
|
@ -117,7 +117,7 @@ MIND_C_API size_t MSTensorGetElementNum(ResMgrHandle res_mgr, const TensorHandle
|
|||
/// \param[in] error Records error code that indicate whether the functions executed successfully.
|
||||
///
|
||||
/// \return The dimension of tensor.
|
||||
MIND_C_API size_t MSTensorGetDimension(ResMgrHandle res_mgr, const TensorHandle tensor, STATUS *error);
|
||||
MIND_C_API size_t MSTensorGetDimension(ResMgrHandle res_mgr, ConstTensorHandle tensor, STATUS *error);
|
||||
|
||||
/// \brief Set the shape of tensor array.
|
||||
///
|
||||
|
@ -127,7 +127,7 @@ MIND_C_API size_t MSTensorGetDimension(ResMgrHandle res_mgr, const TensorHandle
|
|||
/// \param[in] dim The the dimension of tensor, i.e., size of shape array.
|
||||
///
|
||||
/// \return Error code indicates whether the function executed successfully.
|
||||
MIND_C_API STATUS MSTensorSetShape(ResMgrHandle res_mgr, const TensorHandle tensor, int64_t shape[], size_t dim);
|
||||
MIND_C_API STATUS MSTensorSetShape(ResMgrHandle res_mgr, TensorHandle tensor, const int64_t shape[], size_t dim);
|
||||
|
||||
/// \brief Get the shape of tensor array.
|
||||
///
|
||||
|
@ -137,7 +137,7 @@ MIND_C_API STATUS MSTensorSetShape(ResMgrHandle res_mgr, const TensorHandle tens
|
|||
/// \param[in] dim The the dimension of tensor, i.e., size of shape array.
|
||||
///
|
||||
/// \return Error code indicates whether the function executed successfully.
|
||||
MIND_C_API STATUS MSTensorGetShape(ResMgrHandle res_mgr, const TensorHandle tensor, int64_t shape[], size_t dim);
|
||||
MIND_C_API STATUS MSTensorGetShape(ResMgrHandle res_mgr, ConstTensorHandle tensor, int64_t shape[], size_t dim);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@
|
|||
#include "abstract/dshape.h"
|
||||
#include "ir/dtype.h"
|
||||
|
||||
STATUS MSAssignAbstract(ResMgrHandle res_mgr, NodeHandle cur_node, NodeHandle input_node) {
|
||||
STATUS MSAssignAbstract(ResMgrHandle res_mgr, NodeHandle cur_node, ConstNodeHandle input_node) {
|
||||
if (res_mgr == nullptr || cur_node == nullptr || input_node == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [inputs] are nullptr.";
|
||||
return RET_NULL_PTR;
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
#include "c_api/src/common.h"
|
||||
#include "ir/tensor.h"
|
||||
|
||||
PrimitivePtr GetOpPrim(ResMgrHandle res_mgr, NodeHandle node) {
|
||||
PrimitivePtr GetOpPrim(ResMgrHandle res_mgr, ConstNodeHandle node) {
|
||||
auto src_node = GetSrcPtr<CNodePtr>(res_mgr, node);
|
||||
auto node_input = src_node->input(0);
|
||||
if (node_input == nullptr) {
|
||||
|
@ -213,7 +213,7 @@ STATUS MSOpSetAttrString(ResMgrHandle res_mgr, NodeHandle op, const char *attr_n
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
int64_t MSOpGetScalarAttrInt64(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, STATUS *error) {
|
||||
int64_t MSOpGetScalarAttrInt64(ResMgrHandle res_mgr, ConstNodeHandle op, const char *attr_name, STATUS *error) {
|
||||
if (error == nullptr) {
|
||||
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
|
||||
return 0;
|
||||
|
@ -240,7 +240,7 @@ int64_t MSOpGetScalarAttrInt64(ResMgrHandle res_mgr, NodeHandle op, const char *
|
|||
}
|
||||
}
|
||||
|
||||
STATUS MSOpGetAttrArrayInt64(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, int64_t values[],
|
||||
STATUS MSOpGetAttrArrayInt64(ResMgrHandle res_mgr, ConstNodeHandle op, const char *attr_name, int64_t values[],
|
||||
size_t value_num) {
|
||||
if (res_mgr == nullptr || op == nullptr || attr_name == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] or [attr_name] is nullptr.";
|
||||
|
|
|
@ -18,12 +18,14 @@
|
|||
#define MINDSPORE_CCSRC_C_API_SRC_COMMON_H_
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ir/func_graph.h"
|
||||
#include "ops/primitive_c.h"
|
||||
|
||||
using FuncGraphImpl = mindspore::FuncGraph;
|
||||
using FuncGraphManagerImpl = mindspore::FuncGraphManager;
|
||||
using AnfNodeImpl = mindspore::AnfNode;
|
||||
using ParameterImpl = mindspore::Parameter;
|
||||
using ValueNodeImpl = mindspore::ValueNode;
|
||||
using CNodeImpl = mindspore::CNode;
|
||||
using PrimitiveImpl = mindspore::Primitive;
|
||||
|
@ -64,6 +66,7 @@ using CNodePtr = mindspore::CNodePtr;
|
|||
using ParameterPtr = mindspore::ParameterPtr;
|
||||
using AbstractBasePtr = mindspore::abstract::AbstractBasePtr;
|
||||
using FuncGraphPtr = mindspore::FuncGraphPtr;
|
||||
using FuncGraphManagerPtr = std::shared_ptr<mindspore::FuncGraphManager>;
|
||||
|
||||
using AttrMap = mindspore::HashMap<std::string, ValuePtr>;
|
||||
|
||||
|
|
|
@ -21,8 +21,11 @@
|
|||
#include "base/base.h"
|
||||
#include "ops/core_ops.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "utils/ms_context.h"
|
||||
#include "backend/graph_compiler/backend.h"
|
||||
#include "pipeline/jit/pass.h"
|
||||
|
||||
GraphHandle MSFuncGraphCreate(ResMgrHandle res_mgr) {
|
||||
if (res_mgr == nullptr) {
|
||||
|
@ -33,7 +36,7 @@ GraphHandle MSFuncGraphCreate(ResMgrHandle res_mgr) {
|
|||
return GetRawPtr(res_mgr, fg);
|
||||
}
|
||||
|
||||
NodeHandle MSFuncGraphGetInput(ResMgrHandle res_mgr, const GraphHandle graph, size_t i) {
|
||||
NodeHandle MSFuncGraphGetInput(ResMgrHandle res_mgr, ConstGraphHandle graph, size_t i) {
|
||||
if (res_mgr == nullptr || graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [cnode] is nullptr.";
|
||||
return nullptr;
|
||||
|
@ -53,7 +56,7 @@ NodeHandle MSFuncGraphGetInput(ResMgrHandle res_mgr, const GraphHandle graph, si
|
|||
}
|
||||
}
|
||||
|
||||
size_t MSFuncGraphGetInputNum(ResMgrHandle res_mgr, const GraphHandle graph, STATUS *error) {
|
||||
size_t MSFuncGraphGetInputNum(ResMgrHandle res_mgr, ConstGraphHandle graph, STATUS *error) {
|
||||
if (error == nullptr) {
|
||||
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
|
||||
return 0;
|
||||
|
@ -77,7 +80,7 @@ size_t MSFuncGraphGetInputNum(ResMgrHandle res_mgr, const GraphHandle graph, STA
|
|||
return input_num;
|
||||
}
|
||||
|
||||
STATUS MSFuncGraphGetInputs(ResMgrHandle res_mgr, const GraphHandle graph, NodeHandle inputs[], size_t input_num) {
|
||||
STATUS MSFuncGraphGetInputs(ResMgrHandle res_mgr, ConstGraphHandle graph, NodeHandle inputs[], size_t input_num) {
|
||||
if (res_mgr == nullptr || graph == nullptr || inputs == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [inputs] is nullptr.";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -100,7 +103,7 @@ STATUS MSFuncGraphGetInputs(ResMgrHandle res_mgr, const GraphHandle graph, NodeH
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS MSFuncGraphSetOutput(ResMgrHandle res_mgr, GraphHandle graph, const NodeHandle op_node, bool force_new_ret) {
|
||||
STATUS MSFuncGraphSetOutput(ResMgrHandle res_mgr, GraphHandle graph, ConstNodeHandle op_node, bool force_new_ret) {
|
||||
if (res_mgr == nullptr || graph == nullptr || op_node == nullptr) {
|
||||
MS_LOG(ERROR) << "Input GraphHandle [res_mgr] or [graph] or [op_node] is nullptr.";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -118,7 +121,7 @@ STATUS MSFuncGraphSetOutput(ResMgrHandle res_mgr, GraphHandle graph, const NodeH
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS MSFuncGraphSetOutputs(ResMgrHandle res_mgr, GraphHandle graph, const Handle outputs[], size_t output_num,
|
||||
STATUS MSFuncGraphSetOutputs(ResMgrHandle res_mgr, GraphHandle graph, Handle const outputs[], size_t output_num,
|
||||
bool force_new_ret) {
|
||||
if (res_mgr == nullptr || graph == nullptr || outputs == nullptr) {
|
||||
MS_LOG(ERROR) << "Input GraphHandle [res_mgr] or [graph] or [outputs] is nullptr.";
|
||||
|
@ -146,7 +149,7 @@ STATUS MSFuncGraphSetOutputs(ResMgrHandle res_mgr, GraphHandle graph, const Hand
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
NodeHandle MSFuncGraphGetOutput(ResMgrHandle res_mgr, const GraphHandle graph, size_t i) {
|
||||
NodeHandle MSFuncGraphGetOutput(ResMgrHandle res_mgr, ConstGraphHandle graph, size_t i) {
|
||||
if (res_mgr == nullptr || graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] is nullptr.";
|
||||
return nullptr;
|
||||
|
@ -178,7 +181,7 @@ NodeHandle MSFuncGraphGetOutput(ResMgrHandle res_mgr, const GraphHandle graph, s
|
|||
}
|
||||
}
|
||||
|
||||
size_t MSFuncGraphGetOutputNum(ResMgrHandle res_mgr, GraphHandle graph, STATUS *error) {
|
||||
size_t MSFuncGraphGetOutputNum(ResMgrHandle res_mgr, ConstGraphHandle graph, STATUS *error) {
|
||||
if (error == nullptr) {
|
||||
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
|
||||
return 0;
|
||||
|
@ -208,7 +211,7 @@ size_t MSFuncGraphGetOutputNum(ResMgrHandle res_mgr, GraphHandle graph, STATUS *
|
|||
return out_num;
|
||||
}
|
||||
|
||||
STATUS MSFuncGraphGetOutputs(ResMgrHandle res_mgr, const GraphHandle graph, NodeHandle outputs[], size_t output_num) {
|
||||
STATUS MSFuncGraphGetOutputs(ResMgrHandle res_mgr, ConstGraphHandle graph, NodeHandle outputs[], size_t output_num) {
|
||||
if (res_mgr == nullptr || graph == nullptr || outputs == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [inputs] is nullptr.";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -239,8 +242,7 @@ STATUS MSFuncGraphGetOutputs(ResMgrHandle res_mgr, const GraphHandle graph, Node
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS MSFuncGraphReplace(ResMgrHandle res_mgr, GraphHandle graph, const NodeHandle old_node,
|
||||
const NodeHandle new_node) {
|
||||
STATUS MSFuncGraphReplace(ResMgrHandle res_mgr, GraphHandle graph, ConstNodeHandle old_node, ConstNodeHandle new_node) {
|
||||
if (res_mgr == nullptr || graph == nullptr || old_node == nullptr || new_node == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [old_node] or [new_node] is nullptr.";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -272,9 +274,11 @@ STATUS MSFuncGraphCompile(ResMgrHandle res_mgr, GraphHandle graph) {
|
|||
try {
|
||||
auto func_graph = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
std::vector<FuncGraphPtr> func_graphs = {func_graph};
|
||||
auto fg_mgr = mindspore::Manage(func_graphs, true);
|
||||
auto fg_mgr = mindspore::MakeManager();
|
||||
fg_mgr->AddFuncGraph(func_graph, true);
|
||||
MS_EXCEPTION_IF_NULL(fg_mgr);
|
||||
func_graph->set_manager(fg_mgr);
|
||||
(void)mindspore::LiftingClone(func_graph);
|
||||
context_ptr->Refresh();
|
||||
std::string backend_name = context_ptr->backend_policy();
|
||||
std::string target = context_ptr->get_param<std::string>(mindspore::MS_CTX_DEVICE_TARGET);
|
||||
|
@ -295,7 +299,7 @@ STATUS MSFuncGraphCompile(ResMgrHandle res_mgr, GraphHandle graph) {
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS MSFuncGraphRun(ResMgrHandle res_mgr, GraphHandle graph, const TensorHandle inputs[], size_t input_num,
|
||||
STATUS MSFuncGraphRun(ResMgrHandle res_mgr, GraphHandle graph, TensorHandle const inputs[], size_t input_num,
|
||||
TensorHandle outputs[], size_t outputs_num) {
|
||||
if (res_mgr == nullptr || inputs == nullptr || outputs == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [inputs] or [outputs] is nullptr.";
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
Handle GetRawPtr(ResMgrHandle res_mgr, const BasePtr &src_ptr);
|
||||
|
||||
template <typename T>
|
||||
T GetSrcPtr(ResMgrHandle res_mgr, Handle raw_ptr) {
|
||||
T GetSrcPtr(ResMgrHandle res_mgr, ConstHandle raw_ptr) {
|
||||
auto res_mgr_ptr = reinterpret_cast<ResourceManager *>(res_mgr);
|
||||
BasePtr base_ptr = res_mgr_ptr->GetSrcPtr(raw_ptr);
|
||||
if (base_ptr == nullptr) {
|
||||
|
|
|
@ -21,8 +21,15 @@
|
|||
#include "base/base.h"
|
||||
#include "ops/core_ops.h"
|
||||
#include "ir/param_info.h"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/scope.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
#include "backend/common/optimizer/helper.h"
|
||||
|
||||
constexpr size_t firstInIdx = 1;
|
||||
constexpr size_t secondInIdx = 2;
|
||||
constexpr size_t switchInputNum = 3;
|
||||
|
||||
STATUS SetAttrs(ResMgrHandle res_mgr, const PrimitivePtr &prim, char **attr_names, AttrHandle attrs[],
|
||||
size_t attr_num) {
|
||||
AttrMap attr_map{};
|
||||
|
@ -48,7 +55,7 @@ STATUS SetAttrs(ResMgrHandle res_mgr, const PrimitivePtr &prim, char **attr_name
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
NodeHandle MSNewOp(ResMgrHandle res_mgr, GraphHandle graph, const char *op_type, const Handle inputs[],
|
||||
NodeHandle MSNewOp(ResMgrHandle res_mgr, GraphHandle graph, const char *op_type, Handle const inputs[],
|
||||
size_t input_num, char **attr_names, AttrHandle attrs[], size_t attr_num) {
|
||||
if (res_mgr == nullptr || graph == nullptr || op_type == nullptr || inputs == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [op_type] or [inputs] is nullptr.";
|
||||
|
@ -78,11 +85,15 @@ NodeHandle MSNewOp(ResMgrHandle res_mgr, GraphHandle graph, const char *op_type,
|
|||
for (size_t i = 0; i < input_num; ++i) {
|
||||
auto input = GetSrcPtr<AnfNodePtr>(res_mgr, inputs[i]);
|
||||
MS_EXCEPTION_IF_NULL(input);
|
||||
if (input->isa<ParameterImpl>() && input->func_graph() != res_fg) {
|
||||
res_fg->AddFreeVariable(input);
|
||||
}
|
||||
ConvertConstScalarInputToTensor(input);
|
||||
cnode_inputs.push_back(input);
|
||||
abs_list.push_back(input->abstract());
|
||||
}
|
||||
cnode = res_fg->NewCNode(cnode_inputs);
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (res_mgr_ptr->GetInfer()) {
|
||||
auto out_abs = mindspore::opt::CppInferShapeAndType(prim, abs_list);
|
||||
cnode->set_abstract(out_abs);
|
||||
|
@ -95,7 +106,7 @@ NodeHandle MSNewOp(ResMgrHandle res_mgr, GraphHandle graph, const char *op_type,
|
|||
return GetRawPtr(res_mgr, cnode);
|
||||
}
|
||||
|
||||
NodeHandle MSPackNodesTuple(ResMgrHandle res_mgr, GraphHandle graph, const Handle nodes[], size_t node_num) {
|
||||
NodeHandle MSPackNodesTuple(ResMgrHandle res_mgr, GraphHandle graph, Handle const nodes[], size_t node_num) {
|
||||
if (res_mgr == nullptr || graph == nullptr || nodes == nullptr) {
|
||||
MS_LOG(ERROR) << "Input GraphHandle [res_mgr] or [graph] or [nodes] is nullptr.";
|
||||
return nullptr;
|
||||
|
@ -114,6 +125,7 @@ NodeHandle MSPackNodesTuple(ResMgrHandle res_mgr, GraphHandle graph, const Handl
|
|||
abs_list.push_back(in_node->abstract());
|
||||
}
|
||||
make_tuple_cnode = res_fg->NewCNode(in_nodes);
|
||||
MS_EXCEPTION_IF_NULL(make_tuple_cnode);
|
||||
make_tuple_cnode->set_abstract(std::make_shared<AbstractTupleImpl>(abs_list));
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(ERROR) << "FuncGraph set output failed. Error info: " << e.what();
|
||||
|
@ -122,7 +134,7 @@ NodeHandle MSPackNodesTuple(ResMgrHandle res_mgr, GraphHandle graph, const Handl
|
|||
return GetRawPtr(res_mgr, make_tuple_cnode);
|
||||
}
|
||||
|
||||
NodeHandle MSOpGetSpecOutput(ResMgrHandle res_mgr, GraphHandle graph, const NodeHandle op, size_t i) {
|
||||
NodeHandle MSOpGetSpecOutput(ResMgrHandle res_mgr, GraphHandle graph, ConstNodeHandle op, size_t i) {
|
||||
if (res_mgr == nullptr || graph == nullptr || op == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] is nullptr.";
|
||||
return nullptr;
|
||||
|
@ -148,6 +160,7 @@ NodeHandle MSOpGetSpecOutput(ResMgrHandle res_mgr, GraphHandle graph, const Node
|
|||
auto abs_scalar = std::make_shared<mindspore::abstract::AbstractScalar>(mindspore::SizeToInt(i));
|
||||
idx->set_abstract(abs_scalar);
|
||||
ret_node = res_fg->NewCNode({NewValueNode(mindspore::prim::kPrimTupleGetItem), cnode, idx});
|
||||
MS_EXCEPTION_IF_NULL(ret_node);
|
||||
ret_node->set_abstract(abs->cast<mindspore::abstract::AbstractTuplePtr>()->elements()[i]);
|
||||
} else {
|
||||
if (i >= 1) {
|
||||
|
@ -167,7 +180,229 @@ NodeHandle MSOpGetSpecOutput(ResMgrHandle res_mgr, GraphHandle graph, const Node
|
|||
return GetRawPtr(res_mgr, ret_node);
|
||||
}
|
||||
|
||||
NodeHandle MSOpGetInput(ResMgrHandle res_mgr, const NodeHandle op, size_t i) {
|
||||
CNodePtr BuildSwitchStructure(ResMgrHandle res_mgr, GraphHandle graph, NodeHandle const switch_input[],
|
||||
size_t input_num, bool set_fg_out) {
|
||||
MS_EXCEPTION_IF_NULL(res_mgr);
|
||||
MS_EXCEPTION_IF_NULL(graph);
|
||||
MS_EXCEPTION_IF_NULL(switch_input);
|
||||
MS_EXCEPTION_IF_CHECK_FAIL(input_num == switchInputNum, "Switch's input number must be 3!");
|
||||
NodeHandle switch_op = MSNewOp(res_mgr, graph, "Switch", switch_input, input_num, NULL, NULL, 0);
|
||||
auto src_switch = GetSrcPtr<CNodePtr>(res_mgr, switch_op);
|
||||
MS_EXCEPTION_IF_NULL(src_switch);
|
||||
auto fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
CNodePtr switch_call = fg->NewCNode({src_switch});
|
||||
MS_EXCEPTION_IF_NULL(switch_call);
|
||||
if (set_fg_out) {
|
||||
fg->set_output(switch_call);
|
||||
}
|
||||
auto first_node = GetSrcPtr<ValueNodePtr>(res_mgr, switch_input[firstInIdx]);
|
||||
MS_EXCEPTION_IF_NULL(first_node);
|
||||
auto second_node = GetSrcPtr<ValueNodePtr>(res_mgr, switch_input[secondInIdx]);
|
||||
MS_EXCEPTION_IF_NULL(second_node);
|
||||
// AddFuncGraphCNodeIndex is used to set cnode_index. A funcgraph's cnode_index is a list of pair
|
||||
// with pair-struct is (CNODE, index). The CNODE is in another funcgraph, who uses the funcgraph as its input.
|
||||
// for eg. if fg1's cnode A uses fg2 as A's first input, then fg2's conde_index is (A, 1)
|
||||
if (first_node->isa<ValueNodeImpl>()) {
|
||||
fg->AddValueNode(first_node);
|
||||
if (mindspore::IsValueNode<FuncGraphImpl>(first_node)) {
|
||||
auto used = mindspore::GetValueNode<FuncGraphPtr>(first_node);
|
||||
used->AddFuncGraphCNodeIndex(
|
||||
std::make_shared<mindspore::CNodeIndexPair>(std::make_pair(src_switch, firstInIdx + 1)));
|
||||
(void)fg->AddFuncGraphUsed(used);
|
||||
}
|
||||
}
|
||||
if (second_node->isa<ValueNodeImpl>()) {
|
||||
fg->AddValueNode(second_node);
|
||||
if (mindspore::IsValueNode<FuncGraphImpl>(second_node)) {
|
||||
auto used = mindspore::GetValueNode<FuncGraphPtr>(second_node);
|
||||
used->AddFuncGraphCNodeIndex(
|
||||
std::make_shared<mindspore::CNodeIndexPair>(std::make_pair(src_switch, secondInIdx + 1)));
|
||||
(void)fg->AddFuncGraphUsed(used);
|
||||
}
|
||||
}
|
||||
// Switch-call's abstract is equal to second branch.
|
||||
if (mindspore::IsValueNode<FuncGraphImpl>(second_node)) {
|
||||
auto sub_fg = mindspore::GetValueNode<FuncGraphPtr>(second_node);
|
||||
switch_call->set_abstract(sub_fg->output()->abstract());
|
||||
}
|
||||
return switch_call;
|
||||
}
|
||||
|
||||
NodeHandle MSNewSwitch(ResMgrHandle res_mgr, GraphHandle graph, Handle cond, ConstGraphHandle true_br,
|
||||
ConstGraphHandle false_br) {
|
||||
if (res_mgr == nullptr || graph == nullptr || cond == nullptr || true_br == nullptr || false_br == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [cond] or [true_br] or [false_br] is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
try {
|
||||
auto src_cond = GetSrcPtr<BasePtr>(res_mgr, cond);
|
||||
MS_EXCEPTION_IF_NULL(src_cond);
|
||||
NodeHandle cond_raw_ptr = nullptr;
|
||||
if (src_cond->isa<FuncGraphImpl>()) {
|
||||
auto cond_graph = src_cond->cast<FuncGraphPtr>();
|
||||
MS_EXCEPTION_IF_NULL(cond_graph);
|
||||
auto cond_node = mindspore::NewValueNode(cond_graph);
|
||||
cond_node->set_abstract(cond_graph->ToAbstract());
|
||||
cond_raw_ptr = GetRawPtr(res_mgr, cond_node);
|
||||
} else {
|
||||
cond_raw_ptr = cond;
|
||||
}
|
||||
auto true_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, true_br);
|
||||
MS_EXCEPTION_IF_NULL(true_fg);
|
||||
auto true_node = mindspore::NewValueNode(true_fg);
|
||||
true_node->set_abstract(true_fg->ToAbstract());
|
||||
NodeHandle true_raw_ptr = GetRawPtr(res_mgr, true_node);
|
||||
|
||||
auto false_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, false_br);
|
||||
MS_EXCEPTION_IF_NULL(false_fg);
|
||||
auto false_node = mindspore::NewValueNode(false_fg);
|
||||
false_node->set_abstract(false_fg->ToAbstract());
|
||||
NodeHandle false_raw_ptr = GetRawPtr(res_mgr, false_node);
|
||||
|
||||
NodeHandle switch_input[] = {cond_raw_ptr, true_raw_ptr, false_raw_ptr};
|
||||
auto switch_call = BuildSwitchStructure(res_mgr, graph, switch_input, switchInputNum, false);
|
||||
MS_EXCEPTION_IF_NULL(switch_call);
|
||||
return GetRawPtr(res_mgr, switch_call);
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(ERROR) << "New Switch node failed. Error info: " << e.what();
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
void HandleFVInWhileGraph(const FuncGraphPtr &main_fg, const FuncGraphPtr &body_fg, const FuncGraphPtr &after_fg) {
|
||||
std::vector<AnfNodePtr> fv_to_restore{};
|
||||
auto body_fvs = body_fg->free_variables();
|
||||
for (const auto &fv : body_fvs) {
|
||||
auto fv_node = fv.first;
|
||||
MS_EXCEPTION_IF_NULL(fv_node);
|
||||
if (fv_node->func_graph() != main_fg &&
|
||||
std::find(fv_to_restore.begin(), fv_to_restore.end(), fv_node) == fv_to_restore.end()) {
|
||||
fv_to_restore.push_back(fv_node);
|
||||
}
|
||||
}
|
||||
auto after_fvs = after_fg->free_variables();
|
||||
for (const auto &fv : after_fvs) {
|
||||
auto fv_node = fv.first;
|
||||
MS_EXCEPTION_IF_NULL(fv_node);
|
||||
if (fv_node->func_graph() != main_fg &&
|
||||
std::find(fv_to_restore.begin(), fv_to_restore.end(), fv_node) == fv_to_restore.end()) {
|
||||
fv_to_restore.push_back(fv_node);
|
||||
}
|
||||
}
|
||||
|
||||
(void)mindspore::LiftingClone(main_fg);
|
||||
|
||||
auto main_manager = Manage(main_fg);
|
||||
std::vector<AnfNodePtr> new_main_params{};
|
||||
auto main_params = main_fg->parameters();
|
||||
for (const auto &main_param : main_params) {
|
||||
auto src_main_param = main_param->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(src_main_param);
|
||||
auto found_in_fv_list =
|
||||
find_if(fv_to_restore.begin(), fv_to_restore.end(), [&main_param](const AnfNodePtr &fv_param) {
|
||||
return !main_param->ToString().empty() && main_param->ToString() == fv_param->ToString();
|
||||
});
|
||||
if (found_in_fv_list != fv_to_restore.end()) {
|
||||
(void)main_manager->Replace(main_param, *found_in_fv_list);
|
||||
} else if (src_main_param->has_default()) {
|
||||
auto const_input = mindspore::NewValueNode(src_main_param->default_param());
|
||||
const_input->set_abstract(src_main_param->abstract());
|
||||
(void)main_manager->Replace(main_param, const_input);
|
||||
} else {
|
||||
new_main_params.push_back(main_param);
|
||||
}
|
||||
}
|
||||
main_fg->set_parameters(new_main_params);
|
||||
}
|
||||
|
||||
NodeHandle MSNewWhile(ResMgrHandle res_mgr, GraphHandle graph, Handle cond, GraphHandle body_graph,
|
||||
GraphHandle after_graph) {
|
||||
if (res_mgr == nullptr || graph == nullptr || cond == nullptr || body_graph == nullptr || after_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [cond] or [body_graph] or [after_graph] is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
try {
|
||||
auto src_cond = GetSrcPtr<BasePtr>(res_mgr, cond);
|
||||
MS_EXCEPTION_IF_NULL(src_cond);
|
||||
NodeHandle cond_raw_ptr = nullptr;
|
||||
GraphHandle cond_graph = nullptr;
|
||||
FuncGraphPtr src_cond_graph = nullptr;
|
||||
if (src_cond->isa<FuncGraphImpl>()) {
|
||||
cond_graph = cond;
|
||||
src_cond_graph = src_cond->cast<FuncGraphPtr>();
|
||||
MS_EXCEPTION_IF_NULL(src_cond_graph);
|
||||
auto cond_node = src_cond_graph->output();
|
||||
MS_EXCEPTION_IF_NULL(cond_node);
|
||||
cond_raw_ptr = GetRawPtr(res_mgr, cond_node);
|
||||
} else {
|
||||
cond_graph = MSFuncGraphCreate(res_mgr);
|
||||
MS_EXCEPTION_IF_NULL(cond_graph);
|
||||
src_cond_graph = GetSrcPtr<FuncGraphPtr>(res_mgr, cond_graph);
|
||||
MS_EXCEPTION_IF_NULL(src_cond_graph);
|
||||
if (src_cond->isa<CNodeImpl>()) {
|
||||
auto cond_node = src_cond->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cond_node);
|
||||
auto new_cond = src_cond_graph->NewCNode(cond_node->inputs());
|
||||
MS_EXCEPTION_IF_NULL(new_cond);
|
||||
new_cond->set_abstract(cond_node->abstract());
|
||||
cond_raw_ptr = GetRawPtr(res_mgr, new_cond);
|
||||
} else {
|
||||
cond_raw_ptr = cond;
|
||||
}
|
||||
}
|
||||
|
||||
auto body_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, body_graph);
|
||||
MS_EXCEPTION_IF_NULL(body_fg);
|
||||
auto body_node = mindspore::NewValueNode(body_fg);
|
||||
body_node->set_abstract(body_fg->ToAbstract());
|
||||
NodeHandle body_raw_ptr = GetRawPtr(res_mgr, body_node);
|
||||
|
||||
auto after_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, after_graph);
|
||||
MS_EXCEPTION_IF_NULL(after_fg);
|
||||
auto after_node = mindspore::NewValueNode(after_fg);
|
||||
after_node->set_abstract(after_fg->ToAbstract());
|
||||
NodeHandle after_raw_ptr = GetRawPtr(res_mgr, after_node);
|
||||
|
||||
NodeHandle switch_input[] = {cond_raw_ptr, body_raw_ptr, after_raw_ptr};
|
||||
(void)BuildSwitchStructure(res_mgr, cond_graph, switch_input, switchInputNum, true);
|
||||
|
||||
// handle main graph call
|
||||
auto main_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
|
||||
NodeHandle main_func_call = MSNewFuncCallNode(res_mgr, graph, cond_graph, nullptr, 0);
|
||||
auto src_call = GetSrcPtr<AnfNodePtr>(res_mgr, main_func_call);
|
||||
main_fg->set_output(src_call);
|
||||
|
||||
// handle free parameters in while graphs
|
||||
HandleFVInWhileGraph(main_fg, body_fg, after_fg);
|
||||
|
||||
// handle multi outputs in body graph
|
||||
auto sub_graph_node = mindspore::NewValueNode(src_cond_graph);
|
||||
sub_graph_node->set_abstract(src_cond_graph->ToAbstract());
|
||||
std::vector<AnfNodePtr> sub_input_nodes{sub_graph_node};
|
||||
auto body_out_node = body_fg->output();
|
||||
MS_EXCEPTION_IF_NULL(body_out_node);
|
||||
if (IsPrimitiveCNode(body_out_node, mindspore::prim::kPrimMakeTuple)) {
|
||||
auto body_out_cnode = body_out_node->cast<CNodePtr>();
|
||||
for (size_t i = 1; i < body_out_cnode->size(); i++) {
|
||||
sub_input_nodes.push_back(body_out_cnode->input(i));
|
||||
}
|
||||
} else {
|
||||
sub_input_nodes.push_back(body_out_node);
|
||||
}
|
||||
auto body_func_call = body_fg->NewCNode(sub_input_nodes);
|
||||
MS_EXCEPTION_IF_NULL(src_cond_graph->output());
|
||||
body_func_call->set_abstract(src_cond_graph->output()->abstract());
|
||||
MS_EXCEPTION_IF_NULL(body_func_call);
|
||||
body_fg->set_output(body_func_call);
|
||||
return main_func_call;
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(ERROR) << "New While node failed. Error info: " << e.what();
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
NodeHandle MSOpGetInput(ResMgrHandle res_mgr, ConstNodeHandle op, size_t i) {
|
||||
if (res_mgr == nullptr || op == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] is nullptr.";
|
||||
return nullptr;
|
||||
|
@ -188,7 +423,7 @@ NodeHandle MSOpGetInput(ResMgrHandle res_mgr, const NodeHandle op, size_t i) {
|
|||
return GetRawPtr(res_mgr, anf_node);
|
||||
}
|
||||
|
||||
size_t MSOpGetInputsNum(ResMgrHandle res_mgr, const NodeHandle op, STATUS *error) {
|
||||
size_t MSOpGetInputsNum(ResMgrHandle res_mgr, ConstNodeHandle op, STATUS *error) {
|
||||
if (error == nullptr) {
|
||||
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
|
||||
return 0;
|
||||
|
@ -212,7 +447,7 @@ size_t MSOpGetInputsNum(ResMgrHandle res_mgr, const NodeHandle op, STATUS *error
|
|||
return input_num;
|
||||
}
|
||||
|
||||
STATUS MSOpGetInputs(ResMgrHandle res_mgr, const NodeHandle op, NodeHandle inputs[], size_t input_num) {
|
||||
STATUS MSOpGetInputs(ResMgrHandle res_mgr, ConstNodeHandle op, NodeHandle inputs[], size_t input_num) {
|
||||
if (res_mgr == nullptr || op == nullptr || inputs == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] or [inputs] is nullptr.";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -236,10 +471,10 @@ STATUS MSOpGetInputs(ResMgrHandle res_mgr, const NodeHandle op, NodeHandle input
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
NodeHandle MSNewSubGraphNode(ResMgrHandle res_mgr, GraphHandle graph, GraphHandle sub_graph, const Handle inputs[],
|
||||
NodeHandle MSNewFuncCallNode(ResMgrHandle res_mgr, GraphHandle graph, ConstGraphHandle sub_graph, Handle const inputs[],
|
||||
size_t input_num) {
|
||||
if (res_mgr == nullptr || graph == nullptr || sub_graph == nullptr || inputs == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [sub_graph] or [inputs] is nullptr.";
|
||||
if (res_mgr == nullptr || graph == nullptr || sub_graph == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [sub_graph] is nullptr.";
|
||||
return nullptr;
|
||||
}
|
||||
CNodePtr cnode = nullptr;
|
||||
|
@ -248,20 +483,22 @@ NodeHandle MSNewSubGraphNode(ResMgrHandle res_mgr, GraphHandle graph, GraphHandl
|
|||
MS_EXCEPTION_IF_NULL(res_fg);
|
||||
auto res_sub_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, sub_graph);
|
||||
MS_EXCEPTION_IF_NULL(res_sub_fg);
|
||||
auto sub_fg_node = mindspore::NewValueNode(res_sub_fg);
|
||||
std::vector<AnfNodePtr> cnode_inputs{};
|
||||
cnode_inputs.push_back(sub_fg_node);
|
||||
auto sub_node = mindspore::NewValueNode(res_sub_fg);
|
||||
sub_node->set_abstract(res_sub_fg->ToAbstract());
|
||||
std::vector<AnfNodePtr> cnode_inputs{sub_node};
|
||||
for (size_t i = 0; i < input_num; ++i) {
|
||||
auto cnode_input = GetSrcPtr<AnfNodePtr>(res_mgr, inputs[i]);
|
||||
MS_EXCEPTION_IF_NULL(cnode_input);
|
||||
cnode_inputs.push_back(cnode_input);
|
||||
}
|
||||
cnode = res_fg->NewCNode(cnode_inputs);
|
||||
MS_EXCEPTION_IF_NULL(res_sub_fg->output());
|
||||
cnode->set_abstract(res_sub_fg->output()->abstract());
|
||||
} catch (const std::exception &e) {
|
||||
MS_LOG(ERROR) << "FuncGraph create SubGraph node failed. Error info: " << e.what();
|
||||
return nullptr;
|
||||
}
|
||||
MS_LOG(INFO) << "Add subgraph node";
|
||||
MS_LOG(INFO) << "Add function call node";
|
||||
return GetRawPtr(res_mgr, cnode);
|
||||
}
|
||||
|
||||
|
@ -309,7 +546,7 @@ NodeHandle MSNewTensorVariable(ResMgrHandle res_mgr, GraphHandle graph, void *da
|
|||
return GetRawPtr(res_mgr, param);
|
||||
}
|
||||
|
||||
NodeHandle MSNewTensorVariableFromTensor(ResMgrHandle res_mgr, GraphHandle graph, TensorHandle tensor) {
|
||||
NodeHandle MSNewTensorVariableFromTensor(ResMgrHandle res_mgr, GraphHandle graph, ConstTensorHandle tensor) {
|
||||
if (res_mgr == nullptr || graph == nullptr || tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [tensor] is nullptr.";
|
||||
return nullptr;
|
||||
|
@ -330,7 +567,7 @@ NodeHandle MSNewTensorVariableFromTensor(ResMgrHandle res_mgr, GraphHandle graph
|
|||
return GetRawPtr(res_mgr, param);
|
||||
}
|
||||
|
||||
size_t MSTensorVariableGetDataSize(ResMgrHandle res_mgr, NodeHandle node, STATUS *error) {
|
||||
size_t MSTensorVariableGetDataSize(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
|
||||
if (error == nullptr) {
|
||||
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
|
||||
return 0;
|
||||
|
@ -357,7 +594,7 @@ size_t MSTensorVariableGetDataSize(ResMgrHandle res_mgr, NodeHandle node, STATUS
|
|||
}
|
||||
}
|
||||
|
||||
void *MSTensorVariableGetData(ResMgrHandle res_mgr, NodeHandle node) {
|
||||
void *MSTensorVariableGetData(ResMgrHandle res_mgr, ConstNodeHandle node) {
|
||||
if (res_mgr == nullptr || node == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] is nullptr.";
|
||||
return nullptr;
|
||||
|
@ -415,7 +652,7 @@ NodeHandle MSNewTensorConstantFromTensor(ResMgrHandle res_mgr, TensorHandle tens
|
|||
return GetRawPtr(res_mgr, value_node);
|
||||
}
|
||||
|
||||
size_t MSTensorConstantGetDataSize(ResMgrHandle res_mgr, NodeHandle node, STATUS *error) {
|
||||
size_t MSTensorConstantGetDataSize(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
|
||||
if (error == nullptr) {
|
||||
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
|
||||
return 0;
|
||||
|
@ -442,7 +679,7 @@ size_t MSTensorConstantGetDataSize(ResMgrHandle res_mgr, NodeHandle node, STATUS
|
|||
}
|
||||
}
|
||||
|
||||
void *MSTensorConstantGetData(ResMgrHandle res_mgr, NodeHandle node) {
|
||||
void *MSTensorConstantGetData(ResMgrHandle res_mgr, ConstNodeHandle node) {
|
||||
if (res_mgr == nullptr || node == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] is nullptr.";
|
||||
return nullptr;
|
||||
|
@ -548,7 +785,7 @@ NodeHandle MSNewTypeConstant(ResMgrHandle res_mgr, TypeId type) {
|
|||
return GetRawPtr(res_mgr, value_node);
|
||||
}
|
||||
|
||||
int MSScalarConstantGetValueInt32(ResMgrHandle res_mgr, const NodeHandle node, STATUS *error) {
|
||||
int MSScalarConstantGetValueInt32(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
|
||||
MS_LOG(INFO) << "Get Int32 Scalar Value!";
|
||||
if (error == nullptr) {
|
||||
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
|
||||
|
@ -585,7 +822,7 @@ int MSScalarConstantGetValueInt32(ResMgrHandle res_mgr, const NodeHandle node, S
|
|||
return ret_val;
|
||||
}
|
||||
|
||||
float MSScalarConstantGetValueFloat32(ResMgrHandle res_mgr, const NodeHandle node, STATUS *error) {
|
||||
float MSScalarConstantGetValueFloat32(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
|
||||
MS_LOG(INFO) << "Get Float32 Scalar Value!";
|
||||
if (error == nullptr) {
|
||||
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
|
||||
|
@ -622,7 +859,7 @@ float MSScalarConstantGetValueFloat32(ResMgrHandle res_mgr, const NodeHandle nod
|
|||
return ret_val;
|
||||
}
|
||||
|
||||
bool MSScalarConstantGetValueBool(ResMgrHandle res_mgr, const NodeHandle node, STATUS *error) {
|
||||
bool MSScalarConstantGetValueBool(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
|
||||
MS_LOG(INFO) << "Get Bool Scalar Value!";
|
||||
if (error == nullptr) {
|
||||
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
|
||||
|
@ -659,7 +896,7 @@ bool MSScalarConstantGetValueBool(ResMgrHandle res_mgr, const NodeHandle node, S
|
|||
return ret_val;
|
||||
}
|
||||
|
||||
int64_t MSScalarConstantGetValueInt64(ResMgrHandle res_mgr, const NodeHandle node, STATUS *error) {
|
||||
int64_t MSScalarConstantGetValueInt64(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
|
||||
MS_LOG(INFO) << "Get Int64 Scalar Value!";
|
||||
if (error == nullptr) {
|
||||
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
|
||||
|
@ -696,7 +933,7 @@ int64_t MSScalarConstantGetValueInt64(ResMgrHandle res_mgr, const NodeHandle nod
|
|||
return ret_val;
|
||||
}
|
||||
|
||||
STATUS MSStringConstantGetValue(ResMgrHandle res_mgr, const NodeHandle node, char str_buf[], size_t str_len) {
|
||||
STATUS MSStringConstantGetValue(ResMgrHandle res_mgr, ConstNodeHandle node, char str_buf[], size_t str_len) {
|
||||
MS_LOG(INFO) << "Get String Constant Value!";
|
||||
if (res_mgr == nullptr || node == nullptr || str_buf == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] or [str_buf] is nullptr.";
|
||||
|
@ -721,7 +958,7 @@ STATUS MSStringConstantGetValue(ResMgrHandle res_mgr, const NodeHandle node, cha
|
|||
}
|
||||
}
|
||||
|
||||
size_t MSTupleConstantGetSize(ResMgrHandle res_mgr, const NodeHandle node, STATUS *error) {
|
||||
size_t MSTupleConstantGetSize(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
|
||||
MS_LOG(INFO) << "Get Tuple Constant size!";
|
||||
if (error == nullptr) {
|
||||
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
|
||||
|
@ -748,7 +985,7 @@ size_t MSTupleConstantGetSize(ResMgrHandle res_mgr, const NodeHandle node, STATU
|
|||
}
|
||||
}
|
||||
|
||||
STATUS MSTupleConstantGetValueInt64(ResMgrHandle res_mgr, const NodeHandle node, int64_t vec[], size_t size) {
|
||||
STATUS MSTupleConstantGetValueInt64(ResMgrHandle res_mgr, ConstNodeHandle node, int64_t vec[], size_t size) {
|
||||
MS_LOG(INFO) << "Get Tuple Constant Value!";
|
||||
if (res_mgr == nullptr || node == nullptr || vec == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] or [vec] is nullptr.";
|
||||
|
@ -776,7 +1013,7 @@ STATUS MSTupleConstantGetValueInt64(ResMgrHandle res_mgr, const NodeHandle node,
|
|||
}
|
||||
}
|
||||
|
||||
TypeId MSTypeConstantGetValue(ResMgrHandle res_mgr, const NodeHandle node, STATUS *error) {
|
||||
TypeId MSTypeConstantGetValue(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
|
||||
MS_LOG(INFO) << "Get Type Constant Value!";
|
||||
if (error == nullptr) {
|
||||
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
|
||||
|
@ -803,7 +1040,7 @@ TypeId MSTypeConstantGetValue(ResMgrHandle res_mgr, const NodeHandle node, STATU
|
|||
}
|
||||
}
|
||||
|
||||
STATUS MSOpSetName(ResMgrHandle res_mgr, const NodeHandle node, const char *name) {
|
||||
STATUS MSOpSetName(ResMgrHandle res_mgr, NodeHandle node, const char *name) {
|
||||
if (res_mgr == nullptr || node == nullptr || name == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] or [name] is nullptr.";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -817,7 +1054,7 @@ STATUS MSOpSetName(ResMgrHandle res_mgr, const NodeHandle node, const char *name
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS MSNodeGetName(ResMgrHandle res_mgr, const NodeHandle node, char str_buf[], size_t str_len) {
|
||||
STATUS MSNodeGetName(ResMgrHandle res_mgr, ConstNodeHandle node, char str_buf[], size_t str_len) {
|
||||
if (res_mgr == nullptr || node == nullptr || str_buf == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] or [str_buf] is nullptr.";
|
||||
return RET_NULL_PTR;
|
||||
|
|
|
@ -34,6 +34,7 @@ class ResourceManager {
|
|||
context_ = mindspore::MsContext::GetInstance();
|
||||
org_policy_ = context_->backend_policy();
|
||||
context_->set_backend_policy("ms");
|
||||
context_->set_param<int>(mindspore::MS_CTX_EXECUTION_MODE, mindspore::kGraphMode);
|
||||
}
|
||||
|
||||
~ResourceManager() { context_->set_backend_policy(org_policy_); }
|
||||
|
@ -60,7 +61,7 @@ class ResourceManager {
|
|||
(void)ptr_res_pool_.insert(std::make_pair(reinterpret_cast<Handle>(src_ptr.get()), src_ptr));
|
||||
}
|
||||
|
||||
BasePtr GetSrcPtr(Handle ptr) {
|
||||
BasePtr GetSrcPtr(ConstHandle ptr) {
|
||||
auto iter = ptr_res_pool_.find(ptr);
|
||||
if (iter != ptr_res_pool_.end()) {
|
||||
return iter->second;
|
||||
|
@ -70,7 +71,7 @@ class ResourceManager {
|
|||
}
|
||||
}
|
||||
|
||||
void ReleaseSrcPtr(Handle ptr) {
|
||||
void ReleaseSrcPtr(ConstHandle ptr) {
|
||||
auto iter = ptr_res_pool_.find(ptr);
|
||||
if (iter != ptr_res_pool_.end()) {
|
||||
(void)ptr_res_pool_.erase(iter);
|
||||
|
@ -78,10 +79,10 @@ class ResourceManager {
|
|||
}
|
||||
|
||||
private:
|
||||
std::unordered_map<Handle, BasePtr> ptr_res_pool_;
|
||||
std::unordered_map<ConstHandle, BasePtr> ptr_res_pool_;
|
||||
mindspore::HashMap<std::string, mindspore::Any> results_{};
|
||||
std::shared_ptr<mindspore::compile::Backend> backend_ = nullptr;
|
||||
std::shared_ptr<mindspore::MsContext> context_;
|
||||
std::shared_ptr<mindspore::MsContext> context_ = nullptr;
|
||||
std::string org_policy_;
|
||||
bool auto_infer_;
|
||||
};
|
||||
|
|
|
@ -113,7 +113,7 @@ TensorHandle MSNewTensorWithSrcType(ResMgrHandle res_mgr, void *data, const int6
|
|||
return GetRawPtr(res_mgr, tensor);
|
||||
}
|
||||
|
||||
void *MSTensorGetData(ResMgrHandle res_mgr, const TensorHandle tensor) {
|
||||
void *MSTensorGetData(ResMgrHandle res_mgr, ConstTensorHandle tensor) {
|
||||
if (res_mgr == nullptr || tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [tensor] is nullptr.";
|
||||
return nullptr;
|
||||
|
@ -126,7 +126,7 @@ void *MSTensorGetData(ResMgrHandle res_mgr, const TensorHandle tensor) {
|
|||
return src_tensor->data_c();
|
||||
}
|
||||
|
||||
STATUS MSTensorSetDataType(ResMgrHandle res_mgr, const TensorHandle tensor, TypeId type) {
|
||||
STATUS MSTensorSetDataType(ResMgrHandle res_mgr, TensorHandle tensor, TypeId type) {
|
||||
if (res_mgr == nullptr || tensor == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [tensor] is nullptr.";
|
||||
return RET_ERROR;
|
||||
|
@ -140,7 +140,7 @@ STATUS MSTensorSetDataType(ResMgrHandle res_mgr, const TensorHandle tensor, Type
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
TypeId MSTensorGetDataType(ResMgrHandle res_mgr, const TensorHandle tensor, STATUS *error) {
|
||||
TypeId MSTensorGetDataType(ResMgrHandle res_mgr, ConstTensorHandle tensor, STATUS *error) {
|
||||
if (error == nullptr) {
|
||||
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
|
||||
return (enum TypeId)0;
|
||||
|
@ -160,7 +160,7 @@ TypeId MSTensorGetDataType(ResMgrHandle res_mgr, const TensorHandle tensor, STAT
|
|||
return (enum TypeId)(src_tensor->data_type_c());
|
||||
}
|
||||
|
||||
size_t MSTensorGetDataSize(ResMgrHandle res_mgr, const TensorHandle tensor, STATUS *error) {
|
||||
size_t MSTensorGetDataSize(ResMgrHandle res_mgr, ConstTensorHandle tensor, STATUS *error) {
|
||||
if (error == nullptr) {
|
||||
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
|
||||
return 0;
|
||||
|
@ -181,7 +181,7 @@ size_t MSTensorGetDataSize(ResMgrHandle res_mgr, const TensorHandle tensor, STAT
|
|||
return size;
|
||||
}
|
||||
|
||||
size_t MSTensorGetElementNum(ResMgrHandle res_mgr, const TensorHandle tensor, STATUS *error) {
|
||||
size_t MSTensorGetElementNum(ResMgrHandle res_mgr, ConstTensorHandle tensor, STATUS *error) {
|
||||
if (error == nullptr) {
|
||||
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
|
||||
return 0;
|
||||
|
@ -202,7 +202,7 @@ size_t MSTensorGetElementNum(ResMgrHandle res_mgr, const TensorHandle tensor, ST
|
|||
return ele_num;
|
||||
}
|
||||
|
||||
size_t MSTensorGetDimension(ResMgrHandle res_mgr, const TensorHandle tensor, STATUS *error) {
|
||||
size_t MSTensorGetDimension(ResMgrHandle res_mgr, ConstTensorHandle tensor, STATUS *error) {
|
||||
if (error == nullptr) {
|
||||
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
|
||||
return 0;
|
||||
|
@ -223,7 +223,7 @@ size_t MSTensorGetDimension(ResMgrHandle res_mgr, const TensorHandle tensor, STA
|
|||
return dim;
|
||||
}
|
||||
|
||||
STATUS MSTensorSetShape(ResMgrHandle res_mgr, const TensorHandle tensor, int64_t shape[], size_t dim) {
|
||||
STATUS MSTensorSetShape(ResMgrHandle res_mgr, TensorHandle tensor, const int64_t shape[], size_t dim) {
|
||||
if (res_mgr == nullptr || tensor == nullptr || shape == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [tensor] or [shape] is nullptr.";
|
||||
return RET_NULL_PTR;
|
||||
|
@ -243,7 +243,7 @@ STATUS MSTensorSetShape(ResMgrHandle res_mgr, const TensorHandle tensor, int64_t
|
|||
return RET_OK;
|
||||
}
|
||||
|
||||
STATUS MSTensorGetShape(ResMgrHandle res_mgr, const TensorHandle tensor, int64_t shape[], size_t dim) {
|
||||
STATUS MSTensorGetShape(ResMgrHandle res_mgr, ConstTensorHandle tensor, int64_t shape[], size_t dim) {
|
||||
if (res_mgr == nullptr || tensor == nullptr || shape == nullptr) {
|
||||
MS_LOG(ERROR) << "Input Handle [res_mgr] or [tensor] or [shape] is nullptr.";
|
||||
return RET_NULL_PTR;
|
||||
|
|
Loading…
Reference in New Issue