!48484 [MS] c api add control flow api

Merge pull request !48484 from XianglongZeng/c_api_pr
This commit is contained in:
i-robot 2023-02-10 08:10:30 +00:00 committed by Gitee
commit 6e010ea825
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
14 changed files with 384 additions and 106 deletions

View File

@ -26,9 +26,17 @@ typedef void *TensorHandle;
typedef void *NodeHandle;
typedef void *AttrHandle;
typedef void *GraphHandle;
typedef void *FuncGraphManagerHandle;
typedef void *FuncGraphMgrHandle;
typedef void *ResMgrHandle;
typedef const void *ConstHandle;
typedef const void *ConstTensorHandle;
typedef const void *ConstNodeHandle;
typedef const void *ConstAttrHandle;
typedef const void *ConstGraphHandle;
typedef const void *ConstFuncGraphMgrHandle;
typedef const void *ConstResMgrHandle;
#ifdef __cplusplus
}
#endif

View File

@ -35,7 +35,7 @@ extern "C" {
/// \param[in] input_node The input node which contains the Abstract.
///
/// \return Error code indicates whether the function executed successfully.
MIND_C_API STATUS MSAssignAbstract(ResMgrHandle res_mgr, NodeHandle cur_node, NodeHandle input_node);
MIND_C_API STATUS MSAssignAbstract(ResMgrHandle res_mgr, NodeHandle cur_node, ConstNodeHandle input_node);
/// \brief Set Abstract to the node with type and shape.
///

View File

@ -134,7 +134,8 @@ MIND_C_API STATUS MSOpSetAttrString(ResMgrHandle res_mgr, NodeHandle op, const c
/// \param[in] error Error code indicates whether the function executed successfully.
///
/// \return Attribute value
MIND_C_API int64_t MSOpGetScalarAttrInt64(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, STATUS *error);
MIND_C_API int64_t MSOpGetScalarAttrInt64(ResMgrHandle res_mgr, ConstNodeHandle op, const char *attr_name,
STATUS *error);
/// \brief Get the attribute of the target node with the given attribute name.
///
@ -145,8 +146,8 @@ MIND_C_API int64_t MSOpGetScalarAttrInt64(ResMgrHandle res_mgr, NodeHandle op, c
/// \param[in] value_num Size of the given array.
///
/// \return Error code indicates whether the function executed successfully.
MIND_C_API STATUS MSOpGetAttrArrayInt64(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, int64_t values[],
size_t value_num);
MIND_C_API STATUS MSOpGetAttrArrayInt64(ResMgrHandle res_mgr, ConstNodeHandle op, const char *attr_name,
int64_t values[], size_t value_num);
/// \brief Create new Int64 attribute scalar value.
///

View File

@ -43,7 +43,7 @@ MIND_C_API GraphHandle MSFuncGraphCreate(ResMgrHandle res_mgr);
/// \param[in] i Index of the input node.
///
/// \return The created function graph.
MIND_C_API NodeHandle MSFuncGraphGetInput(ResMgrHandle res_mgr, const GraphHandle graph, size_t i);
MIND_C_API NodeHandle MSFuncGraphGetInput(ResMgrHandle res_mgr, ConstGraphHandle graph, size_t i);
/// \brief Get the inputs number of the function graph.
///
@ -52,7 +52,7 @@ MIND_C_API NodeHandle MSFuncGraphGetInput(ResMgrHandle res_mgr, const GraphHandl
/// \param[in] error Records error code that indicate whether the functions executed successfully.
///
/// \return the inputs number of the function graph.
MIND_C_API size_t MSFuncGraphGetInputNum(ResMgrHandle res_mgr, const GraphHandle graph, STATUS *error);
MIND_C_API size_t MSFuncGraphGetInputNum(ResMgrHandle res_mgr, ConstGraphHandle graph, STATUS *error);
/// \brief Get all inputs of the function graph.
///
@ -62,7 +62,7 @@ MIND_C_API size_t MSFuncGraphGetInputNum(ResMgrHandle res_mgr, const GraphHandle
/// \param[in] input_num The length of the array.
///
/// \return Error code that indicate whether the functions executed successfully.
MIND_C_API STATUS MSFuncGraphGetInputs(ResMgrHandle res_mgr, const GraphHandle graph, NodeHandle inputs[],
MIND_C_API STATUS MSFuncGraphGetInputs(ResMgrHandle res_mgr, ConstGraphHandle graph, NodeHandle inputs[],
size_t input_num);
/// \brief Set the output node.
@ -73,7 +73,7 @@ MIND_C_API STATUS MSFuncGraphGetInputs(ResMgrHandle res_mgr, const GraphHandle g
/// \param[in] force_new_ret If true, a new return node is always created.
///
/// \return Error code that indicate whether the functions executed successfully.
MIND_C_API STATUS MSFuncGraphSetOutput(ResMgrHandle res_mgr, GraphHandle graph, const NodeHandle op_node,
MIND_C_API STATUS MSFuncGraphSetOutput(ResMgrHandle res_mgr, GraphHandle graph, ConstNodeHandle op_node,
bool force_new_ret);
/// \brief Set the output node.
@ -84,7 +84,7 @@ MIND_C_API STATUS MSFuncGraphSetOutput(ResMgrHandle res_mgr, GraphHandle graph,
/// \param[in] force_new_ret If true, a new return node is always created.
///
/// \return Error code that indicate whether the functions executed successfully.
MIND_C_API STATUS MSFuncGraphSetOutputs(ResMgrHandle res_mgr, GraphHandle graph, const Handle outputs[],
MIND_C_API STATUS MSFuncGraphSetOutputs(ResMgrHandle res_mgr, GraphHandle graph, Handle const outputs[],
size_t output_num, bool force_new_ret);
/// \brief Get the output node according to the index.
@ -94,7 +94,7 @@ MIND_C_API STATUS MSFuncGraphSetOutputs(ResMgrHandle res_mgr, GraphHandle graph,
/// \param[in] i The index to get the output. If there is only one output for graph, the i should be 0;
///
/// \return The output node, nullptr if output not set.
MIND_C_API NodeHandle MSFuncGraphGetOutput(ResMgrHandle res_mgr, const GraphHandle graph, size_t i);
MIND_C_API NodeHandle MSFuncGraphGetOutput(ResMgrHandle res_mgr, ConstGraphHandle graph, size_t i);
/// \brief Get the outputs number of the function graph.
///
@ -103,7 +103,7 @@ MIND_C_API NodeHandle MSFuncGraphGetOutput(ResMgrHandle res_mgr, const GraphHand
/// \param[in] error Records error code that indicate whether the functions executed successfully.
///
/// \return the outputs number of the function graph.
MIND_C_API size_t MSFuncGraphGetOutputNum(ResMgrHandle res_mgr, GraphHandle graph, STATUS *error);
MIND_C_API size_t MSFuncGraphGetOutputNum(ResMgrHandle res_mgr, ConstGraphHandle graph, STATUS *error);
/// \brief Get all outputs of the function graph.
///
@ -113,7 +113,7 @@ MIND_C_API size_t MSFuncGraphGetOutputNum(ResMgrHandle res_mgr, GraphHandle grap
/// \param[in] output_num The length of the array.
///
/// \return Error code that indicate whether the functions executed successfully.
MIND_C_API STATUS MSFuncGraphGetOutputs(ResMgrHandle res_mgr, const GraphHandle graph, NodeHandle outputs[],
MIND_C_API STATUS MSFuncGraphGetOutputs(ResMgrHandle res_mgr, ConstGraphHandle graph, NodeHandle outputs[],
size_t output_num);
/// \brief Replace a node in a function graph.
@ -124,8 +124,8 @@ MIND_C_API STATUS MSFuncGraphGetOutputs(ResMgrHandle res_mgr, const GraphHandle
/// \param[in] new_node The replace node.
///
/// \return Error code that indicate whether the functions executed successfully.
MIND_C_API STATUS MSFuncGraphReplace(ResMgrHandle res_mgr, GraphHandle graph, const NodeHandle old_node,
const NodeHandle new_node);
MIND_C_API STATUS MSFuncGraphReplace(ResMgrHandle res_mgr, GraphHandle graph, ConstNodeHandle old_node,
ConstNodeHandle new_node);
/// \brief Compile the function graph.
///
@ -145,7 +145,7 @@ MIND_C_API STATUS MSFuncGraphCompile(ResMgrHandle res_mgr, GraphHandle graph);
/// \param[in] outputs_num The output size.
///
/// \return Error code that indicate whether the function graph executed successfully.
MIND_C_API STATUS MSFuncGraphRun(ResMgrHandle res_mgr, GraphHandle graph, const TensorHandle inputs[], size_t input_num,
MIND_C_API STATUS MSFuncGraphRun(ResMgrHandle res_mgr, GraphHandle graph, TensorHandle const inputs[], size_t input_num,
TensorHandle outputs[], size_t outputs_num);
#ifdef __cplusplus
}

View File

@ -42,7 +42,7 @@ extern "C" {
/// \param[in] attr_num The number of attributes.
///
/// \return The created Operator node handle
MIND_C_API NodeHandle MSNewOp(ResMgrHandle res_mgr, GraphHandle graph, const char *op_type, const Handle inputs[],
MIND_C_API NodeHandle MSNewOp(ResMgrHandle res_mgr, GraphHandle graph, const char *op_type, Handle const inputs[],
size_t input_num, char **attr_names, AttrHandle attrs[], size_t attr_num);
/// \brief Pack nodes into a Tuple node.
@ -53,9 +53,9 @@ MIND_C_API NodeHandle MSNewOp(ResMgrHandle res_mgr, GraphHandle graph, const cha
/// \param[in] node_num The number of nodes in the array.
///
/// \return The created Tuple node handle.
MIND_C_API NodeHandle MSPackNodesTuple(ResMgrHandle res_mgr, GraphHandle graph, const Handle nodes[], size_t node_num);
MIND_C_API NodeHandle MSPackNodesTuple(ResMgrHandle res_mgr, GraphHandle graph, Handle const nodes[], size_t node_num);
/// \brief Get specified output branch from as multi-output Operator.
/// \brief Get specified output branch from a multi-output Operator.
///
/// \param[in] res_mgr Resource manager that saves allocated instance resources.
/// \param[in] graph The given function graph pointer handle.
@ -63,7 +63,31 @@ MIND_C_API NodeHandle MSPackNodesTuple(ResMgrHandle res_mgr, GraphHandle graph,
/// \param[in] i The index of the output branch.
///
/// \return The obtained output node.
MIND_C_API NodeHandle MSOpGetSpecOutput(ResMgrHandle res_mgr, GraphHandle graph, const NodeHandle op, size_t i);
MIND_C_API NodeHandle MSOpGetSpecOutput(ResMgrHandle res_mgr, GraphHandle graph, ConstNodeHandle op, size_t i);
/// \brief Create a Switch operator for control-flow scene.
///
/// \param[in] res_mgr Resource manager that saves allocated instance resources.
/// \param[in] graph The given function graph pointer handle.
/// \param[in] cond The condition of Switch which can be an Operator or a Subgraph.
/// \param[in] true_br The true branch of Switch which must be a Subgraph.
/// \param[in] false_br The false branch of Switch which must be a Subgraph.
///
/// \return The created Switch operator node.
MIND_C_API NodeHandle MSNewSwitch(ResMgrHandle res_mgr, GraphHandle graph, Handle cond, ConstGraphHandle true_br,
ConstGraphHandle false_br);
/// \brief Create a While operator for control-flow scene.
///
/// \param[in] res_mgr Resource manager that saves allocated instance resources.
/// \param[in] graph The given function graph pointer handle.
/// \param[in] cond The condition of While which can be an Operator or a Subgraph.
/// \param[in] body_graph The loop body of While which must be a Subgraph.
/// \param[in] after_graph The graph after stepping out the While which must be a Subgraph.
///
/// \return The created While operator node.
MIND_C_API NodeHandle MSNewWhile(ResMgrHandle res_mgr, GraphHandle graph, Handle cond, GraphHandle body_graph,
GraphHandle after_graph);
/// \brief Get specified input node of Operator.
///
@ -72,7 +96,7 @@ MIND_C_API NodeHandle MSOpGetSpecOutput(ResMgrHandle res_mgr, GraphHandle graph,
/// \param[in] i The index of the input.
///
/// \return The obtained input node handle.
MIND_C_API NodeHandle MSOpGetInput(ResMgrHandle res_mgr, const NodeHandle op, size_t i);
MIND_C_API NodeHandle MSOpGetInput(ResMgrHandle res_mgr, ConstNodeHandle op, size_t i);
/// \brief Get the input nodes number of Operator.
///
@ -81,7 +105,7 @@ MIND_C_API NodeHandle MSOpGetInput(ResMgrHandle res_mgr, const NodeHandle op, si
/// \param[in] error Records error code that indicate whether the functions executed successfully.
///
/// \return The input nodes number.
MIND_C_API size_t MSOpGetInputsNum(ResMgrHandle res_mgr, const NodeHandle op, STATUS *error);
MIND_C_API size_t MSOpGetInputsNum(ResMgrHandle res_mgr, ConstNodeHandle op, STATUS *error);
/// \brief Get all input nodes of the Operator.
///
@ -91,7 +115,7 @@ MIND_C_API size_t MSOpGetInputsNum(ResMgrHandle res_mgr, const NodeHandle op, ST
/// \param[in] input_num The size of the input array.
///
/// \return Error code that indicate whether the functions executed successfully.
MIND_C_API STATUS MSOpGetInputs(ResMgrHandle res_mgr, const NodeHandle op, NodeHandle inputs[], size_t input_num);
MIND_C_API STATUS MSOpGetInputs(ResMgrHandle res_mgr, ConstNodeHandle op, NodeHandle inputs[], size_t input_num);
/// \brief Create a subgraph node.
///
@ -102,8 +126,8 @@ MIND_C_API STATUS MSOpGetInputs(ResMgrHandle res_mgr, const NodeHandle op, NodeH
/// \param[in] input_num The number of the input array.
///
/// \return The created subgraph node handle.
MIND_C_API NodeHandle MSNewSubGraphNode(ResMgrHandle res_mgr, GraphHandle graph, GraphHandle sub_graph,
const Handle inputs[], size_t input_num);
MIND_C_API NodeHandle MSNewFuncCallNode(ResMgrHandle res_mgr, GraphHandle graph, ConstGraphHandle sub_graph,
Handle const inputs[], size_t input_num);
/// \brief Create a Placeholder node, which is usually the input of graph without data.
///
@ -138,7 +162,7 @@ MIND_C_API NodeHandle MSNewTensorVariable(ResMgrHandle res_mgr, GraphHandle grap
/// \param[in] tensor The given Tensor instance.
///
/// \return The created Variable node handle.
MIND_C_API NodeHandle MSNewTensorVariableFromTensor(ResMgrHandle res_mgr, GraphHandle graph, TensorHandle tensor);
MIND_C_API NodeHandle MSNewTensorVariableFromTensor(ResMgrHandle res_mgr, GraphHandle graph, ConstTensorHandle tensor);
/// \brief Get data size of a Tensor Variable.
///
@ -147,7 +171,7 @@ MIND_C_API NodeHandle MSNewTensorVariableFromTensor(ResMgrHandle res_mgr, GraphH
/// \param[in] error Records error code that indicate whether the functions executed successfully.
///
/// \return The data byte size.
MIND_C_API size_t MSTensorVariableGetDataSize(ResMgrHandle res_mgr, NodeHandle node, STATUS *error);
MIND_C_API size_t MSTensorVariableGetDataSize(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error);
/// \brief Get data from a Tensor Variable.
///
@ -155,7 +179,7 @@ MIND_C_API size_t MSTensorVariableGetDataSize(ResMgrHandle res_mgr, NodeHandle n
/// \param[in] node The tensor variable node
///
/// \return The data.
MIND_C_API void *MSTensorVariableGetData(ResMgrHandle res_mgr, NodeHandle node);
MIND_C_API void *MSTensorVariableGetData(ResMgrHandle res_mgr, ConstNodeHandle node);
/// \brief Create a Constant node of tensor, which contains constant tensor data.
///
@ -185,7 +209,7 @@ MIND_C_API NodeHandle MSNewTensorConstantFromTensor(ResMgrHandle res_mgr, Tensor
/// \param[in] error Records error code that indicate whether the functions executed successfully.
///
/// \return The data byte size.
MIND_C_API size_t MSTensorConstantGetDataSize(ResMgrHandle res_mgr, NodeHandle node, STATUS *error);
MIND_C_API size_t MSTensorConstantGetDataSize(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error);
/// \brief Get data from a Tensor Constant.
///
@ -193,7 +217,7 @@ MIND_C_API size_t MSTensorConstantGetDataSize(ResMgrHandle res_mgr, NodeHandle n
/// \param[in] node The tensor constant node
///
/// \return The data.
MIND_C_API void *MSTensorConstantGetData(ResMgrHandle res_mgr, NodeHandle node);
MIND_C_API void *MSTensorConstantGetData(ResMgrHandle res_mgr, ConstNodeHandle node);
/// \brief Create Constant node of a float scalar.
///
@ -259,7 +283,7 @@ MIND_C_API NodeHandle MSNewTypeConstant(ResMgrHandle res_mgr, TypeId type);
/// \param[in] error Records error code that indicate whether the functions executed successfully.
///
/// \return The obtained int32 value.
MIND_C_API int MSScalarConstantGetValueInt32(ResMgrHandle res_mgr, const NodeHandle node, STATUS *error);
MIND_C_API int MSScalarConstantGetValueInt32(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error);
/// \brief Get value from the float32 scalar Constant node.
///
@ -268,7 +292,7 @@ MIND_C_API int MSScalarConstantGetValueInt32(ResMgrHandle res_mgr, const NodeHan
/// \param[in] error Records error code that indicate whether the functions executed successfully.
///
/// \return The obtained float32 value.
MIND_C_API float MSScalarConstantGetValueFloat32(ResMgrHandle res_mgr, const NodeHandle node, STATUS *error);
MIND_C_API float MSScalarConstantGetValueFloat32(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error);
/// \brief Get value from the bool scalar Constant node.
///
@ -277,7 +301,7 @@ MIND_C_API float MSScalarConstantGetValueFloat32(ResMgrHandle res_mgr, const Nod
/// \param[in] error Records error code that indicate whether the functions executed successfully.
///
/// \return The obtained bool value.
MIND_C_API bool MSScalarConstantGetValueBool(ResMgrHandle res_mgr, const NodeHandle node, STATUS *error);
MIND_C_API bool MSScalarConstantGetValueBool(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error);
/// \brief Get value from the int64 scalar Constant node.
///
@ -286,7 +310,7 @@ MIND_C_API bool MSScalarConstantGetValueBool(ResMgrHandle res_mgr, const NodeHan
/// \param[in] error Records error code that indicate whether the functions executed successfully.
///
/// \return The obtained int64 value.
MIND_C_API int64_t MSScalarConstantGetValueInt64(ResMgrHandle res_mgr, const NodeHandle node, STATUS *error);
MIND_C_API int64_t MSScalarConstantGetValueInt64(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error);
/// \brief Get value from the string Constant node.
///
@ -296,7 +320,7 @@ MIND_C_API int64_t MSScalarConstantGetValueInt64(ResMgrHandle res_mgr, const Nod
/// \param[in] str_len The size of the char array.
///
/// \return The error code that indicate whether the functions executed successfully.
MIND_C_API STATUS MSStringConstantGetValue(ResMgrHandle res_mgr, const NodeHandle node, char str_buf[], size_t str_len);
MIND_C_API STATUS MSStringConstantGetValue(ResMgrHandle res_mgr, ConstNodeHandle node, char str_buf[], size_t str_len);
/// \brief Get value from the tuple Constant node.
///
@ -305,7 +329,7 @@ MIND_C_API STATUS MSStringConstantGetValue(ResMgrHandle res_mgr, const NodeHandl
/// \param[in] error Records error code that indicate whether the functions executed successfully.
///
/// \return The size of the Tuple.
MIND_C_API size_t MSTupleConstantGetSize(ResMgrHandle res_mgr, const NodeHandle node, STATUS *error);
MIND_C_API size_t MSTupleConstantGetSize(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error);
/// \brief Get value from the Tuple Constant node.
///
@ -315,7 +339,7 @@ MIND_C_API size_t MSTupleConstantGetSize(ResMgrHandle res_mgr, const NodeHandle
/// \param[in] size The size of the value vector.
///
/// \return The error code that indicate whether the functions executed successfully.
MIND_C_API STATUS MSTupleConstantGetValueInt64(ResMgrHandle res_mgr, const NodeHandle node, int64_t vec[], size_t size);
MIND_C_API STATUS MSTupleConstantGetValueInt64(ResMgrHandle res_mgr, ConstNodeHandle node, int64_t vec[], size_t size);
/// \brief Get value from the Type Constant node.
///
@ -324,7 +348,7 @@ MIND_C_API STATUS MSTupleConstantGetValueInt64(ResMgrHandle res_mgr, const NodeH
/// \param[in] error Records error code that indicate whether the functions executed successfully.
///
/// \return The obtained type value.
MIND_C_API TypeId MSTypeConstantGetValue(ResMgrHandle res_mgr, const NodeHandle node, STATUS *error);
MIND_C_API TypeId MSTypeConstantGetValue(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error);
/// \brief Set Operator node name.
///
@ -333,7 +357,7 @@ MIND_C_API TypeId MSTypeConstantGetValue(ResMgrHandle res_mgr, const NodeHandle
/// \param[in] name The op node name to be set.
///
/// \return The error code that indicate whether the functions executed successfully.
MIND_C_API STATUS MSOpSetName(ResMgrHandle res_mgr, const NodeHandle node, const char *name);
MIND_C_API STATUS MSOpSetName(ResMgrHandle res_mgr, NodeHandle node, const char *name);
/// \brief Get the name of node.
///
@ -343,7 +367,7 @@ MIND_C_API STATUS MSOpSetName(ResMgrHandle res_mgr, const NodeHandle node, const
/// \param[in] str_len The size of the char array.
///
/// \return The error code that indicate whether the functions executed successfully.
MIND_C_API STATUS MSNodeGetName(ResMgrHandle res_mgr, const NodeHandle node, char str_buf[], size_t str_len);
MIND_C_API STATUS MSNodeGetName(ResMgrHandle res_mgr, ConstNodeHandle node, char str_buf[], size_t str_len);
#ifdef __cplusplus
}

View File

@ -73,7 +73,7 @@ MIND_C_API TensorHandle MSNewTensorWithSrcType(ResMgrHandle res_mgr, void *data,
/// \param[in] tensor The pointer of the tensor instance.
///
/// \return The pointer to the tensor data
MIND_C_API void *MSTensorGetData(ResMgrHandle res_mgr, const TensorHandle tensor);
MIND_C_API void *MSTensorGetData(ResMgrHandle res_mgr, ConstTensorHandle tensor);
/// \brief Set tensor data type.
///
@ -82,7 +82,7 @@ MIND_C_API void *MSTensorGetData(ResMgrHandle res_mgr, const TensorHandle tensor
/// \param[in] type The data type to be set.
///
/// \return Error code that indicate whether the functions executed successfully.
MIND_C_API STATUS MSTensorSetDataType(ResMgrHandle res_mgr, const TensorHandle tensor, TypeId type);
MIND_C_API STATUS MSTensorSetDataType(ResMgrHandle res_mgr, TensorHandle tensor, TypeId type);
/// \brief Get tensor data type.
///
@ -90,7 +90,7 @@ MIND_C_API STATUS MSTensorSetDataType(ResMgrHandle res_mgr, const TensorHandle t
/// \param[in] tensor The pointer of the tensor instance.
///
/// \return The data type of tensor.
MIND_C_API TypeId MSTensorGetDataType(ResMgrHandle res_mgr, const TensorHandle tensor, STATUS *error);
MIND_C_API TypeId MSTensorGetDataType(ResMgrHandle res_mgr, ConstTensorHandle tensor, STATUS *error);
/// \brief Get the byte size of tensor data.
///
@ -99,7 +99,7 @@ MIND_C_API TypeId MSTensorGetDataType(ResMgrHandle res_mgr, const TensorHandle t
/// \param[in] error Records error code that indicate whether the functions executed successfully.
///
/// \return The byte size of tensor data.
MIND_C_API size_t MSTensorGetDataSize(ResMgrHandle res_mgr, const TensorHandle tensor, STATUS *error);
MIND_C_API size_t MSTensorGetDataSize(ResMgrHandle res_mgr, ConstTensorHandle tensor, STATUS *error);
/// \brief Get the element number of tensor array.
///
@ -108,7 +108,7 @@ MIND_C_API size_t MSTensorGetDataSize(ResMgrHandle res_mgr, const TensorHandle t
/// \param[in] error Records error code that indicate whether the functions executed successfully.
///
/// \return The element number of tensor array.
MIND_C_API size_t MSTensorGetElementNum(ResMgrHandle res_mgr, const TensorHandle tensor, STATUS *error);
MIND_C_API size_t MSTensorGetElementNum(ResMgrHandle res_mgr, ConstTensorHandle tensor, STATUS *error);
/// \brief Get the dimension of tensor.
///
@ -117,7 +117,7 @@ MIND_C_API size_t MSTensorGetElementNum(ResMgrHandle res_mgr, const TensorHandle
/// \param[in] error Records error code that indicate whether the functions executed successfully.
///
/// \return The dimension of tensor.
MIND_C_API size_t MSTensorGetDimension(ResMgrHandle res_mgr, const TensorHandle tensor, STATUS *error);
MIND_C_API size_t MSTensorGetDimension(ResMgrHandle res_mgr, ConstTensorHandle tensor, STATUS *error);
/// \brief Set the shape of tensor array.
///
@ -127,7 +127,7 @@ MIND_C_API size_t MSTensorGetDimension(ResMgrHandle res_mgr, const TensorHandle
/// \param[in] dim The the dimension of tensor, i.e., size of shape array.
///
/// \return Error code indicates whether the function executed successfully.
MIND_C_API STATUS MSTensorSetShape(ResMgrHandle res_mgr, const TensorHandle tensor, int64_t shape[], size_t dim);
MIND_C_API STATUS MSTensorSetShape(ResMgrHandle res_mgr, TensorHandle tensor, const int64_t shape[], size_t dim);
/// \brief Get the shape of tensor array.
///
@ -137,7 +137,7 @@ MIND_C_API STATUS MSTensorSetShape(ResMgrHandle res_mgr, const TensorHandle tens
/// \param[in] dim The the dimension of tensor, i.e., size of shape array.
///
/// \return Error code indicates whether the function executed successfully.
MIND_C_API STATUS MSTensorGetShape(ResMgrHandle res_mgr, const TensorHandle tensor, int64_t shape[], size_t dim);
MIND_C_API STATUS MSTensorGetShape(ResMgrHandle res_mgr, ConstTensorHandle tensor, int64_t shape[], size_t dim);
#ifdef __cplusplus
}

View File

@ -21,7 +21,7 @@
#include "abstract/dshape.h"
#include "ir/dtype.h"
STATUS MSAssignAbstract(ResMgrHandle res_mgr, NodeHandle cur_node, NodeHandle input_node) {
STATUS MSAssignAbstract(ResMgrHandle res_mgr, NodeHandle cur_node, ConstNodeHandle input_node) {
if (res_mgr == nullptr || cur_node == nullptr || input_node == nullptr) {
MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [inputs] are nullptr.";
return RET_NULL_PTR;

View File

@ -20,7 +20,7 @@
#include "c_api/src/common.h"
#include "ir/tensor.h"
PrimitivePtr GetOpPrim(ResMgrHandle res_mgr, NodeHandle node) {
PrimitivePtr GetOpPrim(ResMgrHandle res_mgr, ConstNodeHandle node) {
auto src_node = GetSrcPtr<CNodePtr>(res_mgr, node);
auto node_input = src_node->input(0);
if (node_input == nullptr) {
@ -213,7 +213,7 @@ STATUS MSOpSetAttrString(ResMgrHandle res_mgr, NodeHandle op, const char *attr_n
return RET_OK;
}
int64_t MSOpGetScalarAttrInt64(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, STATUS *error) {
int64_t MSOpGetScalarAttrInt64(ResMgrHandle res_mgr, ConstNodeHandle op, const char *attr_name, STATUS *error) {
if (error == nullptr) {
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
return 0;
@ -240,7 +240,7 @@ int64_t MSOpGetScalarAttrInt64(ResMgrHandle res_mgr, NodeHandle op, const char *
}
}
STATUS MSOpGetAttrArrayInt64(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, int64_t values[],
STATUS MSOpGetAttrArrayInt64(ResMgrHandle res_mgr, ConstNodeHandle op, const char *attr_name, int64_t values[],
size_t value_num) {
if (res_mgr == nullptr || op == nullptr || attr_name == nullptr) {
MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] or [attr_name] is nullptr.";

View File

@ -18,12 +18,14 @@
#define MINDSPORE_CCSRC_C_API_SRC_COMMON_H_
#include <string>
#include <memory>
#include "ir/func_graph.h"
#include "ops/primitive_c.h"
using FuncGraphImpl = mindspore::FuncGraph;
using FuncGraphManagerImpl = mindspore::FuncGraphManager;
using AnfNodeImpl = mindspore::AnfNode;
using ParameterImpl = mindspore::Parameter;
using ValueNodeImpl = mindspore::ValueNode;
using CNodeImpl = mindspore::CNode;
using PrimitiveImpl = mindspore::Primitive;
@ -64,6 +66,7 @@ using CNodePtr = mindspore::CNodePtr;
using ParameterPtr = mindspore::ParameterPtr;
using AbstractBasePtr = mindspore::abstract::AbstractBasePtr;
using FuncGraphPtr = mindspore::FuncGraphPtr;
using FuncGraphManagerPtr = std::shared_ptr<mindspore::FuncGraphManager>;
using AttrMap = mindspore::HashMap<std::string, ValuePtr>;

View File

@ -21,8 +21,11 @@
#include "base/base.h"
#include "ops/core_ops.h"
#include "ir/func_graph.h"
#include "ir/anf.h"
#include "ir/func_graph_cloner.h"
#include "utils/ms_context.h"
#include "backend/graph_compiler/backend.h"
#include "pipeline/jit/pass.h"
GraphHandle MSFuncGraphCreate(ResMgrHandle res_mgr) {
if (res_mgr == nullptr) {
@ -33,7 +36,7 @@ GraphHandle MSFuncGraphCreate(ResMgrHandle res_mgr) {
return GetRawPtr(res_mgr, fg);
}
NodeHandle MSFuncGraphGetInput(ResMgrHandle res_mgr, const GraphHandle graph, size_t i) {
NodeHandle MSFuncGraphGetInput(ResMgrHandle res_mgr, ConstGraphHandle graph, size_t i) {
if (res_mgr == nullptr || graph == nullptr) {
MS_LOG(ERROR) << "Input Handle [res_mgr] or [cnode] is nullptr.";
return nullptr;
@ -53,7 +56,7 @@ NodeHandle MSFuncGraphGetInput(ResMgrHandle res_mgr, const GraphHandle graph, si
}
}
size_t MSFuncGraphGetInputNum(ResMgrHandle res_mgr, const GraphHandle graph, STATUS *error) {
size_t MSFuncGraphGetInputNum(ResMgrHandle res_mgr, ConstGraphHandle graph, STATUS *error) {
if (error == nullptr) {
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
return 0;
@ -77,7 +80,7 @@ size_t MSFuncGraphGetInputNum(ResMgrHandle res_mgr, const GraphHandle graph, STA
return input_num;
}
STATUS MSFuncGraphGetInputs(ResMgrHandle res_mgr, const GraphHandle graph, NodeHandle inputs[], size_t input_num) {
STATUS MSFuncGraphGetInputs(ResMgrHandle res_mgr, ConstGraphHandle graph, NodeHandle inputs[], size_t input_num) {
if (res_mgr == nullptr || graph == nullptr || inputs == nullptr) {
MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [inputs] is nullptr.";
return RET_NULL_PTR;
@ -100,7 +103,7 @@ STATUS MSFuncGraphGetInputs(ResMgrHandle res_mgr, const GraphHandle graph, NodeH
return RET_OK;
}
STATUS MSFuncGraphSetOutput(ResMgrHandle res_mgr, GraphHandle graph, const NodeHandle op_node, bool force_new_ret) {
STATUS MSFuncGraphSetOutput(ResMgrHandle res_mgr, GraphHandle graph, ConstNodeHandle op_node, bool force_new_ret) {
if (res_mgr == nullptr || graph == nullptr || op_node == nullptr) {
MS_LOG(ERROR) << "Input GraphHandle [res_mgr] or [graph] or [op_node] is nullptr.";
return RET_NULL_PTR;
@ -118,7 +121,7 @@ STATUS MSFuncGraphSetOutput(ResMgrHandle res_mgr, GraphHandle graph, const NodeH
return RET_OK;
}
STATUS MSFuncGraphSetOutputs(ResMgrHandle res_mgr, GraphHandle graph, const Handle outputs[], size_t output_num,
STATUS MSFuncGraphSetOutputs(ResMgrHandle res_mgr, GraphHandle graph, Handle const outputs[], size_t output_num,
bool force_new_ret) {
if (res_mgr == nullptr || graph == nullptr || outputs == nullptr) {
MS_LOG(ERROR) << "Input GraphHandle [res_mgr] or [graph] or [outputs] is nullptr.";
@ -146,7 +149,7 @@ STATUS MSFuncGraphSetOutputs(ResMgrHandle res_mgr, GraphHandle graph, const Hand
return RET_OK;
}
NodeHandle MSFuncGraphGetOutput(ResMgrHandle res_mgr, const GraphHandle graph, size_t i) {
NodeHandle MSFuncGraphGetOutput(ResMgrHandle res_mgr, ConstGraphHandle graph, size_t i) {
if (res_mgr == nullptr || graph == nullptr) {
MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] is nullptr.";
return nullptr;
@ -178,7 +181,7 @@ NodeHandle MSFuncGraphGetOutput(ResMgrHandle res_mgr, const GraphHandle graph, s
}
}
size_t MSFuncGraphGetOutputNum(ResMgrHandle res_mgr, GraphHandle graph, STATUS *error) {
size_t MSFuncGraphGetOutputNum(ResMgrHandle res_mgr, ConstGraphHandle graph, STATUS *error) {
if (error == nullptr) {
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
return 0;
@ -208,7 +211,7 @@ size_t MSFuncGraphGetOutputNum(ResMgrHandle res_mgr, GraphHandle graph, STATUS *
return out_num;
}
STATUS MSFuncGraphGetOutputs(ResMgrHandle res_mgr, const GraphHandle graph, NodeHandle outputs[], size_t output_num) {
STATUS MSFuncGraphGetOutputs(ResMgrHandle res_mgr, ConstGraphHandle graph, NodeHandle outputs[], size_t output_num) {
if (res_mgr == nullptr || graph == nullptr || outputs == nullptr) {
MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [inputs] is nullptr.";
return RET_NULL_PTR;
@ -239,8 +242,7 @@ STATUS MSFuncGraphGetOutputs(ResMgrHandle res_mgr, const GraphHandle graph, Node
return RET_OK;
}
STATUS MSFuncGraphReplace(ResMgrHandle res_mgr, GraphHandle graph, const NodeHandle old_node,
const NodeHandle new_node) {
STATUS MSFuncGraphReplace(ResMgrHandle res_mgr, GraphHandle graph, ConstNodeHandle old_node, ConstNodeHandle new_node) {
if (res_mgr == nullptr || graph == nullptr || old_node == nullptr || new_node == nullptr) {
MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [old_node] or [new_node] is nullptr.";
return RET_NULL_PTR;
@ -272,9 +274,11 @@ STATUS MSFuncGraphCompile(ResMgrHandle res_mgr, GraphHandle graph) {
try {
auto func_graph = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<FuncGraphPtr> func_graphs = {func_graph};
auto fg_mgr = mindspore::Manage(func_graphs, true);
auto fg_mgr = mindspore::MakeManager();
fg_mgr->AddFuncGraph(func_graph, true);
MS_EXCEPTION_IF_NULL(fg_mgr);
func_graph->set_manager(fg_mgr);
(void)mindspore::LiftingClone(func_graph);
context_ptr->Refresh();
std::string backend_name = context_ptr->backend_policy();
std::string target = context_ptr->get_param<std::string>(mindspore::MS_CTX_DEVICE_TARGET);
@ -295,7 +299,7 @@ STATUS MSFuncGraphCompile(ResMgrHandle res_mgr, GraphHandle graph) {
return RET_OK;
}
STATUS MSFuncGraphRun(ResMgrHandle res_mgr, GraphHandle graph, const TensorHandle inputs[], size_t input_num,
STATUS MSFuncGraphRun(ResMgrHandle res_mgr, GraphHandle graph, TensorHandle const inputs[], size_t input_num,
TensorHandle outputs[], size_t outputs_num) {
if (res_mgr == nullptr || inputs == nullptr || outputs == nullptr) {
MS_LOG(ERROR) << "Input Handle [res_mgr] or [inputs] or [outputs] is nullptr.";

View File

@ -26,7 +26,7 @@
Handle GetRawPtr(ResMgrHandle res_mgr, const BasePtr &src_ptr);
template <typename T>
T GetSrcPtr(ResMgrHandle res_mgr, Handle raw_ptr) {
T GetSrcPtr(ResMgrHandle res_mgr, ConstHandle raw_ptr) {
auto res_mgr_ptr = reinterpret_cast<ResourceManager *>(res_mgr);
BasePtr base_ptr = res_mgr_ptr->GetSrcPtr(raw_ptr);
if (base_ptr == nullptr) {

View File

@ -21,8 +21,15 @@
#include "base/base.h"
#include "ops/core_ops.h"
#include "ir/param_info.h"
#include "ir/anf.h"
#include "ir/scope.h"
#include "ir/func_graph_cloner.h"
#include "backend/common/optimizer/helper.h"
constexpr size_t firstInIdx = 1;
constexpr size_t secondInIdx = 2;
constexpr size_t switchInputNum = 3;
STATUS SetAttrs(ResMgrHandle res_mgr, const PrimitivePtr &prim, char **attr_names, AttrHandle attrs[],
size_t attr_num) {
AttrMap attr_map{};
@ -48,7 +55,7 @@ STATUS SetAttrs(ResMgrHandle res_mgr, const PrimitivePtr &prim, char **attr_name
return RET_OK;
}
NodeHandle MSNewOp(ResMgrHandle res_mgr, GraphHandle graph, const char *op_type, const Handle inputs[],
NodeHandle MSNewOp(ResMgrHandle res_mgr, GraphHandle graph, const char *op_type, Handle const inputs[],
size_t input_num, char **attr_names, AttrHandle attrs[], size_t attr_num) {
if (res_mgr == nullptr || graph == nullptr || op_type == nullptr || inputs == nullptr) {
MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [op_type] or [inputs] is nullptr.";
@ -78,11 +85,15 @@ NodeHandle MSNewOp(ResMgrHandle res_mgr, GraphHandle graph, const char *op_type,
for (size_t i = 0; i < input_num; ++i) {
auto input = GetSrcPtr<AnfNodePtr>(res_mgr, inputs[i]);
MS_EXCEPTION_IF_NULL(input);
if (input->isa<ParameterImpl>() && input->func_graph() != res_fg) {
res_fg->AddFreeVariable(input);
}
ConvertConstScalarInputToTensor(input);
cnode_inputs.push_back(input);
abs_list.push_back(input->abstract());
}
cnode = res_fg->NewCNode(cnode_inputs);
MS_EXCEPTION_IF_NULL(cnode);
if (res_mgr_ptr->GetInfer()) {
auto out_abs = mindspore::opt::CppInferShapeAndType(prim, abs_list);
cnode->set_abstract(out_abs);
@ -95,7 +106,7 @@ NodeHandle MSNewOp(ResMgrHandle res_mgr, GraphHandle graph, const char *op_type,
return GetRawPtr(res_mgr, cnode);
}
NodeHandle MSPackNodesTuple(ResMgrHandle res_mgr, GraphHandle graph, const Handle nodes[], size_t node_num) {
NodeHandle MSPackNodesTuple(ResMgrHandle res_mgr, GraphHandle graph, Handle const nodes[], size_t node_num) {
if (res_mgr == nullptr || graph == nullptr || nodes == nullptr) {
MS_LOG(ERROR) << "Input GraphHandle [res_mgr] or [graph] or [nodes] is nullptr.";
return nullptr;
@ -114,6 +125,7 @@ NodeHandle MSPackNodesTuple(ResMgrHandle res_mgr, GraphHandle graph, const Handl
abs_list.push_back(in_node->abstract());
}
make_tuple_cnode = res_fg->NewCNode(in_nodes);
MS_EXCEPTION_IF_NULL(make_tuple_cnode);
make_tuple_cnode->set_abstract(std::make_shared<AbstractTupleImpl>(abs_list));
} catch (const std::exception &e) {
MS_LOG(ERROR) << "FuncGraph set output failed. Error info: " << e.what();
@ -122,7 +134,7 @@ NodeHandle MSPackNodesTuple(ResMgrHandle res_mgr, GraphHandle graph, const Handl
return GetRawPtr(res_mgr, make_tuple_cnode);
}
NodeHandle MSOpGetSpecOutput(ResMgrHandle res_mgr, GraphHandle graph, const NodeHandle op, size_t i) {
NodeHandle MSOpGetSpecOutput(ResMgrHandle res_mgr, GraphHandle graph, ConstNodeHandle op, size_t i) {
if (res_mgr == nullptr || graph == nullptr || op == nullptr) {
MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] is nullptr.";
return nullptr;
@ -148,6 +160,7 @@ NodeHandle MSOpGetSpecOutput(ResMgrHandle res_mgr, GraphHandle graph, const Node
auto abs_scalar = std::make_shared<mindspore::abstract::AbstractScalar>(mindspore::SizeToInt(i));
idx->set_abstract(abs_scalar);
ret_node = res_fg->NewCNode({NewValueNode(mindspore::prim::kPrimTupleGetItem), cnode, idx});
MS_EXCEPTION_IF_NULL(ret_node);
ret_node->set_abstract(abs->cast<mindspore::abstract::AbstractTuplePtr>()->elements()[i]);
} else {
if (i >= 1) {
@ -167,7 +180,229 @@ NodeHandle MSOpGetSpecOutput(ResMgrHandle res_mgr, GraphHandle graph, const Node
return GetRawPtr(res_mgr, ret_node);
}
NodeHandle MSOpGetInput(ResMgrHandle res_mgr, const NodeHandle op, size_t i) {
CNodePtr BuildSwitchStructure(ResMgrHandle res_mgr, GraphHandle graph, NodeHandle const switch_input[],
size_t input_num, bool set_fg_out) {
MS_EXCEPTION_IF_NULL(res_mgr);
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(switch_input);
MS_EXCEPTION_IF_CHECK_FAIL(input_num == switchInputNum, "Switch's input number must be 3!");
NodeHandle switch_op = MSNewOp(res_mgr, graph, "Switch", switch_input, input_num, NULL, NULL, 0);
auto src_switch = GetSrcPtr<CNodePtr>(res_mgr, switch_op);
MS_EXCEPTION_IF_NULL(src_switch);
auto fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
MS_EXCEPTION_IF_NULL(fg);
CNodePtr switch_call = fg->NewCNode({src_switch});
MS_EXCEPTION_IF_NULL(switch_call);
if (set_fg_out) {
fg->set_output(switch_call);
}
auto first_node = GetSrcPtr<ValueNodePtr>(res_mgr, switch_input[firstInIdx]);
MS_EXCEPTION_IF_NULL(first_node);
auto second_node = GetSrcPtr<ValueNodePtr>(res_mgr, switch_input[secondInIdx]);
MS_EXCEPTION_IF_NULL(second_node);
// AddFuncGraphCNodeIndex is used to set cnode_index. A funcgraph's cnode_index is a list of pair
// with pair-struct is (CNODE, index). The CNODE is in another funcgraph, who uses the funcgraph as its input.
// for eg. if fg1's cnode A uses fg2 as A's first input, then fg2's conde_index is (A, 1)
if (first_node->isa<ValueNodeImpl>()) {
fg->AddValueNode(first_node);
if (mindspore::IsValueNode<FuncGraphImpl>(first_node)) {
auto used = mindspore::GetValueNode<FuncGraphPtr>(first_node);
used->AddFuncGraphCNodeIndex(
std::make_shared<mindspore::CNodeIndexPair>(std::make_pair(src_switch, firstInIdx + 1)));
(void)fg->AddFuncGraphUsed(used);
}
}
if (second_node->isa<ValueNodeImpl>()) {
fg->AddValueNode(second_node);
if (mindspore::IsValueNode<FuncGraphImpl>(second_node)) {
auto used = mindspore::GetValueNode<FuncGraphPtr>(second_node);
used->AddFuncGraphCNodeIndex(
std::make_shared<mindspore::CNodeIndexPair>(std::make_pair(src_switch, secondInIdx + 1)));
(void)fg->AddFuncGraphUsed(used);
}
}
// Switch-call's abstract is equal to second branch.
if (mindspore::IsValueNode<FuncGraphImpl>(second_node)) {
auto sub_fg = mindspore::GetValueNode<FuncGraphPtr>(second_node);
switch_call->set_abstract(sub_fg->output()->abstract());
}
return switch_call;
}
NodeHandle MSNewSwitch(ResMgrHandle res_mgr, GraphHandle graph, Handle cond, ConstGraphHandle true_br,
ConstGraphHandle false_br) {
if (res_mgr == nullptr || graph == nullptr || cond == nullptr || true_br == nullptr || false_br == nullptr) {
MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [cond] or [true_br] or [false_br] is nullptr.";
return nullptr;
}
try {
auto src_cond = GetSrcPtr<BasePtr>(res_mgr, cond);
MS_EXCEPTION_IF_NULL(src_cond);
NodeHandle cond_raw_ptr = nullptr;
if (src_cond->isa<FuncGraphImpl>()) {
auto cond_graph = src_cond->cast<FuncGraphPtr>();
MS_EXCEPTION_IF_NULL(cond_graph);
auto cond_node = mindspore::NewValueNode(cond_graph);
cond_node->set_abstract(cond_graph->ToAbstract());
cond_raw_ptr = GetRawPtr(res_mgr, cond_node);
} else {
cond_raw_ptr = cond;
}
auto true_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, true_br);
MS_EXCEPTION_IF_NULL(true_fg);
auto true_node = mindspore::NewValueNode(true_fg);
true_node->set_abstract(true_fg->ToAbstract());
NodeHandle true_raw_ptr = GetRawPtr(res_mgr, true_node);
auto false_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, false_br);
MS_EXCEPTION_IF_NULL(false_fg);
auto false_node = mindspore::NewValueNode(false_fg);
false_node->set_abstract(false_fg->ToAbstract());
NodeHandle false_raw_ptr = GetRawPtr(res_mgr, false_node);
NodeHandle switch_input[] = {cond_raw_ptr, true_raw_ptr, false_raw_ptr};
auto switch_call = BuildSwitchStructure(res_mgr, graph, switch_input, switchInputNum, false);
MS_EXCEPTION_IF_NULL(switch_call);
return GetRawPtr(res_mgr, switch_call);
} catch (const std::exception &e) {
MS_LOG(ERROR) << "New Switch node failed. Error info: " << e.what();
return nullptr;
}
}
void HandleFVInWhileGraph(const FuncGraphPtr &main_fg, const FuncGraphPtr &body_fg, const FuncGraphPtr &after_fg) {
std::vector<AnfNodePtr> fv_to_restore{};
auto body_fvs = body_fg->free_variables();
for (const auto &fv : body_fvs) {
auto fv_node = fv.first;
MS_EXCEPTION_IF_NULL(fv_node);
if (fv_node->func_graph() != main_fg &&
std::find(fv_to_restore.begin(), fv_to_restore.end(), fv_node) == fv_to_restore.end()) {
fv_to_restore.push_back(fv_node);
}
}
auto after_fvs = after_fg->free_variables();
for (const auto &fv : after_fvs) {
auto fv_node = fv.first;
MS_EXCEPTION_IF_NULL(fv_node);
if (fv_node->func_graph() != main_fg &&
std::find(fv_to_restore.begin(), fv_to_restore.end(), fv_node) == fv_to_restore.end()) {
fv_to_restore.push_back(fv_node);
}
}
(void)mindspore::LiftingClone(main_fg);
auto main_manager = Manage(main_fg);
std::vector<AnfNodePtr> new_main_params{};
auto main_params = main_fg->parameters();
for (const auto &main_param : main_params) {
auto src_main_param = main_param->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(src_main_param);
auto found_in_fv_list =
find_if(fv_to_restore.begin(), fv_to_restore.end(), [&main_param](const AnfNodePtr &fv_param) {
return !main_param->ToString().empty() && main_param->ToString() == fv_param->ToString();
});
if (found_in_fv_list != fv_to_restore.end()) {
(void)main_manager->Replace(main_param, *found_in_fv_list);
} else if (src_main_param->has_default()) {
auto const_input = mindspore::NewValueNode(src_main_param->default_param());
const_input->set_abstract(src_main_param->abstract());
(void)main_manager->Replace(main_param, const_input);
} else {
new_main_params.push_back(main_param);
}
}
main_fg->set_parameters(new_main_params);
}
NodeHandle MSNewWhile(ResMgrHandle res_mgr, GraphHandle graph, Handle cond, GraphHandle body_graph,
GraphHandle after_graph) {
if (res_mgr == nullptr || graph == nullptr || cond == nullptr || body_graph == nullptr || after_graph == nullptr) {
MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [cond] or [body_graph] or [after_graph] is nullptr.";
return nullptr;
}
try {
auto src_cond = GetSrcPtr<BasePtr>(res_mgr, cond);
MS_EXCEPTION_IF_NULL(src_cond);
NodeHandle cond_raw_ptr = nullptr;
GraphHandle cond_graph = nullptr;
FuncGraphPtr src_cond_graph = nullptr;
if (src_cond->isa<FuncGraphImpl>()) {
cond_graph = cond;
src_cond_graph = src_cond->cast<FuncGraphPtr>();
MS_EXCEPTION_IF_NULL(src_cond_graph);
auto cond_node = src_cond_graph->output();
MS_EXCEPTION_IF_NULL(cond_node);
cond_raw_ptr = GetRawPtr(res_mgr, cond_node);
} else {
cond_graph = MSFuncGraphCreate(res_mgr);
MS_EXCEPTION_IF_NULL(cond_graph);
src_cond_graph = GetSrcPtr<FuncGraphPtr>(res_mgr, cond_graph);
MS_EXCEPTION_IF_NULL(src_cond_graph);
if (src_cond->isa<CNodeImpl>()) {
auto cond_node = src_cond->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cond_node);
auto new_cond = src_cond_graph->NewCNode(cond_node->inputs());
MS_EXCEPTION_IF_NULL(new_cond);
new_cond->set_abstract(cond_node->abstract());
cond_raw_ptr = GetRawPtr(res_mgr, new_cond);
} else {
cond_raw_ptr = cond;
}
}
auto body_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, body_graph);
MS_EXCEPTION_IF_NULL(body_fg);
auto body_node = mindspore::NewValueNode(body_fg);
body_node->set_abstract(body_fg->ToAbstract());
NodeHandle body_raw_ptr = GetRawPtr(res_mgr, body_node);
auto after_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, after_graph);
MS_EXCEPTION_IF_NULL(after_fg);
auto after_node = mindspore::NewValueNode(after_fg);
after_node->set_abstract(after_fg->ToAbstract());
NodeHandle after_raw_ptr = GetRawPtr(res_mgr, after_node);
NodeHandle switch_input[] = {cond_raw_ptr, body_raw_ptr, after_raw_ptr};
(void)BuildSwitchStructure(res_mgr, cond_graph, switch_input, switchInputNum, true);
// handle main graph call
auto main_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, graph);
NodeHandle main_func_call = MSNewFuncCallNode(res_mgr, graph, cond_graph, nullptr, 0);
auto src_call = GetSrcPtr<AnfNodePtr>(res_mgr, main_func_call);
main_fg->set_output(src_call);
// handle free parameters in while graphs
HandleFVInWhileGraph(main_fg, body_fg, after_fg);
// handle multi outputs in body graph
auto sub_graph_node = mindspore::NewValueNode(src_cond_graph);
sub_graph_node->set_abstract(src_cond_graph->ToAbstract());
std::vector<AnfNodePtr> sub_input_nodes{sub_graph_node};
auto body_out_node = body_fg->output();
MS_EXCEPTION_IF_NULL(body_out_node);
if (IsPrimitiveCNode(body_out_node, mindspore::prim::kPrimMakeTuple)) {
auto body_out_cnode = body_out_node->cast<CNodePtr>();
for (size_t i = 1; i < body_out_cnode->size(); i++) {
sub_input_nodes.push_back(body_out_cnode->input(i));
}
} else {
sub_input_nodes.push_back(body_out_node);
}
auto body_func_call = body_fg->NewCNode(sub_input_nodes);
MS_EXCEPTION_IF_NULL(src_cond_graph->output());
body_func_call->set_abstract(src_cond_graph->output()->abstract());
MS_EXCEPTION_IF_NULL(body_func_call);
body_fg->set_output(body_func_call);
return main_func_call;
} catch (const std::exception &e) {
MS_LOG(ERROR) << "New While node failed. Error info: " << e.what();
return nullptr;
}
}
NodeHandle MSOpGetInput(ResMgrHandle res_mgr, ConstNodeHandle op, size_t i) {
if (res_mgr == nullptr || op == nullptr) {
MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] is nullptr.";
return nullptr;
@ -188,7 +423,7 @@ NodeHandle MSOpGetInput(ResMgrHandle res_mgr, const NodeHandle op, size_t i) {
return GetRawPtr(res_mgr, anf_node);
}
size_t MSOpGetInputsNum(ResMgrHandle res_mgr, const NodeHandle op, STATUS *error) {
size_t MSOpGetInputsNum(ResMgrHandle res_mgr, ConstNodeHandle op, STATUS *error) {
if (error == nullptr) {
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
return 0;
@ -212,7 +447,7 @@ size_t MSOpGetInputsNum(ResMgrHandle res_mgr, const NodeHandle op, STATUS *error
return input_num;
}
STATUS MSOpGetInputs(ResMgrHandle res_mgr, const NodeHandle op, NodeHandle inputs[], size_t input_num) {
STATUS MSOpGetInputs(ResMgrHandle res_mgr, ConstNodeHandle op, NodeHandle inputs[], size_t input_num) {
if (res_mgr == nullptr || op == nullptr || inputs == nullptr) {
MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] or [inputs] is nullptr.";
return RET_NULL_PTR;
@ -236,10 +471,10 @@ STATUS MSOpGetInputs(ResMgrHandle res_mgr, const NodeHandle op, NodeHandle input
return RET_OK;
}
NodeHandle MSNewSubGraphNode(ResMgrHandle res_mgr, GraphHandle graph, GraphHandle sub_graph, const Handle inputs[],
NodeHandle MSNewFuncCallNode(ResMgrHandle res_mgr, GraphHandle graph, ConstGraphHandle sub_graph, Handle const inputs[],
size_t input_num) {
if (res_mgr == nullptr || graph == nullptr || sub_graph == nullptr || inputs == nullptr) {
MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [sub_graph] or [inputs] is nullptr.";
if (res_mgr == nullptr || graph == nullptr || sub_graph == nullptr) {
MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [sub_graph] is nullptr.";
return nullptr;
}
CNodePtr cnode = nullptr;
@ -248,20 +483,22 @@ NodeHandle MSNewSubGraphNode(ResMgrHandle res_mgr, GraphHandle graph, GraphHandl
MS_EXCEPTION_IF_NULL(res_fg);
auto res_sub_fg = GetSrcPtr<FuncGraphPtr>(res_mgr, sub_graph);
MS_EXCEPTION_IF_NULL(res_sub_fg);
auto sub_fg_node = mindspore::NewValueNode(res_sub_fg);
std::vector<AnfNodePtr> cnode_inputs{};
cnode_inputs.push_back(sub_fg_node);
auto sub_node = mindspore::NewValueNode(res_sub_fg);
sub_node->set_abstract(res_sub_fg->ToAbstract());
std::vector<AnfNodePtr> cnode_inputs{sub_node};
for (size_t i = 0; i < input_num; ++i) {
auto cnode_input = GetSrcPtr<AnfNodePtr>(res_mgr, inputs[i]);
MS_EXCEPTION_IF_NULL(cnode_input);
cnode_inputs.push_back(cnode_input);
}
cnode = res_fg->NewCNode(cnode_inputs);
MS_EXCEPTION_IF_NULL(res_sub_fg->output());
cnode->set_abstract(res_sub_fg->output()->abstract());
} catch (const std::exception &e) {
MS_LOG(ERROR) << "FuncGraph create SubGraph node failed. Error info: " << e.what();
return nullptr;
}
MS_LOG(INFO) << "Add subgraph node";
MS_LOG(INFO) << "Add function call node";
return GetRawPtr(res_mgr, cnode);
}
@ -309,7 +546,7 @@ NodeHandle MSNewTensorVariable(ResMgrHandle res_mgr, GraphHandle graph, void *da
return GetRawPtr(res_mgr, param);
}
NodeHandle MSNewTensorVariableFromTensor(ResMgrHandle res_mgr, GraphHandle graph, TensorHandle tensor) {
NodeHandle MSNewTensorVariableFromTensor(ResMgrHandle res_mgr, GraphHandle graph, ConstTensorHandle tensor) {
if (res_mgr == nullptr || graph == nullptr || tensor == nullptr) {
MS_LOG(ERROR) << "Input Handle [res_mgr] or [graph] or [tensor] is nullptr.";
return nullptr;
@ -330,7 +567,7 @@ NodeHandle MSNewTensorVariableFromTensor(ResMgrHandle res_mgr, GraphHandle graph
return GetRawPtr(res_mgr, param);
}
size_t MSTensorVariableGetDataSize(ResMgrHandle res_mgr, NodeHandle node, STATUS *error) {
size_t MSTensorVariableGetDataSize(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
if (error == nullptr) {
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
return 0;
@ -357,7 +594,7 @@ size_t MSTensorVariableGetDataSize(ResMgrHandle res_mgr, NodeHandle node, STATUS
}
}
void *MSTensorVariableGetData(ResMgrHandle res_mgr, NodeHandle node) {
void *MSTensorVariableGetData(ResMgrHandle res_mgr, ConstNodeHandle node) {
if (res_mgr == nullptr || node == nullptr) {
MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] is nullptr.";
return nullptr;
@ -415,7 +652,7 @@ NodeHandle MSNewTensorConstantFromTensor(ResMgrHandle res_mgr, TensorHandle tens
return GetRawPtr(res_mgr, value_node);
}
size_t MSTensorConstantGetDataSize(ResMgrHandle res_mgr, NodeHandle node, STATUS *error) {
size_t MSTensorConstantGetDataSize(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
if (error == nullptr) {
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
return 0;
@ -442,7 +679,7 @@ size_t MSTensorConstantGetDataSize(ResMgrHandle res_mgr, NodeHandle node, STATUS
}
}
void *MSTensorConstantGetData(ResMgrHandle res_mgr, NodeHandle node) {
void *MSTensorConstantGetData(ResMgrHandle res_mgr, ConstNodeHandle node) {
if (res_mgr == nullptr || node == nullptr) {
MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] is nullptr.";
return nullptr;
@ -548,7 +785,7 @@ NodeHandle MSNewTypeConstant(ResMgrHandle res_mgr, TypeId type) {
return GetRawPtr(res_mgr, value_node);
}
int MSScalarConstantGetValueInt32(ResMgrHandle res_mgr, const NodeHandle node, STATUS *error) {
int MSScalarConstantGetValueInt32(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
MS_LOG(INFO) << "Get Int32 Scalar Value!";
if (error == nullptr) {
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
@ -585,7 +822,7 @@ int MSScalarConstantGetValueInt32(ResMgrHandle res_mgr, const NodeHandle node, S
return ret_val;
}
float MSScalarConstantGetValueFloat32(ResMgrHandle res_mgr, const NodeHandle node, STATUS *error) {
float MSScalarConstantGetValueFloat32(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
MS_LOG(INFO) << "Get Float32 Scalar Value!";
if (error == nullptr) {
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
@ -622,7 +859,7 @@ float MSScalarConstantGetValueFloat32(ResMgrHandle res_mgr, const NodeHandle nod
return ret_val;
}
bool MSScalarConstantGetValueBool(ResMgrHandle res_mgr, const NodeHandle node, STATUS *error) {
bool MSScalarConstantGetValueBool(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
MS_LOG(INFO) << "Get Bool Scalar Value!";
if (error == nullptr) {
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
@ -659,7 +896,7 @@ bool MSScalarConstantGetValueBool(ResMgrHandle res_mgr, const NodeHandle node, S
return ret_val;
}
int64_t MSScalarConstantGetValueInt64(ResMgrHandle res_mgr, const NodeHandle node, STATUS *error) {
int64_t MSScalarConstantGetValueInt64(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
MS_LOG(INFO) << "Get Int64 Scalar Value!";
if (error == nullptr) {
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
@ -696,7 +933,7 @@ int64_t MSScalarConstantGetValueInt64(ResMgrHandle res_mgr, const NodeHandle nod
return ret_val;
}
STATUS MSStringConstantGetValue(ResMgrHandle res_mgr, const NodeHandle node, char str_buf[], size_t str_len) {
STATUS MSStringConstantGetValue(ResMgrHandle res_mgr, ConstNodeHandle node, char str_buf[], size_t str_len) {
MS_LOG(INFO) << "Get String Constant Value!";
if (res_mgr == nullptr || node == nullptr || str_buf == nullptr) {
MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] or [str_buf] is nullptr.";
@ -721,7 +958,7 @@ STATUS MSStringConstantGetValue(ResMgrHandle res_mgr, const NodeHandle node, cha
}
}
size_t MSTupleConstantGetSize(ResMgrHandle res_mgr, const NodeHandle node, STATUS *error) {
size_t MSTupleConstantGetSize(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
MS_LOG(INFO) << "Get Tuple Constant size!";
if (error == nullptr) {
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
@ -748,7 +985,7 @@ size_t MSTupleConstantGetSize(ResMgrHandle res_mgr, const NodeHandle node, STATU
}
}
STATUS MSTupleConstantGetValueInt64(ResMgrHandle res_mgr, const NodeHandle node, int64_t vec[], size_t size) {
STATUS MSTupleConstantGetValueInt64(ResMgrHandle res_mgr, ConstNodeHandle node, int64_t vec[], size_t size) {
MS_LOG(INFO) << "Get Tuple Constant Value!";
if (res_mgr == nullptr || node == nullptr || vec == nullptr) {
MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] or [vec] is nullptr.";
@ -776,7 +1013,7 @@ STATUS MSTupleConstantGetValueInt64(ResMgrHandle res_mgr, const NodeHandle node,
}
}
TypeId MSTypeConstantGetValue(ResMgrHandle res_mgr, const NodeHandle node, STATUS *error) {
TypeId MSTypeConstantGetValue(ResMgrHandle res_mgr, ConstNodeHandle node, STATUS *error) {
MS_LOG(INFO) << "Get Type Constant Value!";
if (error == nullptr) {
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
@ -803,7 +1040,7 @@ TypeId MSTypeConstantGetValue(ResMgrHandle res_mgr, const NodeHandle node, STATU
}
}
STATUS MSOpSetName(ResMgrHandle res_mgr, const NodeHandle node, const char *name) {
STATUS MSOpSetName(ResMgrHandle res_mgr, NodeHandle node, const char *name) {
if (res_mgr == nullptr || node == nullptr || name == nullptr) {
MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] or [name] is nullptr.";
return RET_NULL_PTR;
@ -817,7 +1054,7 @@ STATUS MSOpSetName(ResMgrHandle res_mgr, const NodeHandle node, const char *name
return RET_OK;
}
STATUS MSNodeGetName(ResMgrHandle res_mgr, const NodeHandle node, char str_buf[], size_t str_len) {
STATUS MSNodeGetName(ResMgrHandle res_mgr, ConstNodeHandle node, char str_buf[], size_t str_len) {
if (res_mgr == nullptr || node == nullptr || str_buf == nullptr) {
MS_LOG(ERROR) << "Input Handle [res_mgr] or [node] or [str_buf] is nullptr.";
return RET_NULL_PTR;

View File

@ -34,6 +34,7 @@ class ResourceManager {
context_ = mindspore::MsContext::GetInstance();
org_policy_ = context_->backend_policy();
context_->set_backend_policy("ms");
context_->set_param<int>(mindspore::MS_CTX_EXECUTION_MODE, mindspore::kGraphMode);
}
~ResourceManager() { context_->set_backend_policy(org_policy_); }
@ -60,7 +61,7 @@ class ResourceManager {
(void)ptr_res_pool_.insert(std::make_pair(reinterpret_cast<Handle>(src_ptr.get()), src_ptr));
}
BasePtr GetSrcPtr(Handle ptr) {
BasePtr GetSrcPtr(ConstHandle ptr) {
auto iter = ptr_res_pool_.find(ptr);
if (iter != ptr_res_pool_.end()) {
return iter->second;
@ -70,7 +71,7 @@ class ResourceManager {
}
}
void ReleaseSrcPtr(Handle ptr) {
void ReleaseSrcPtr(ConstHandle ptr) {
auto iter = ptr_res_pool_.find(ptr);
if (iter != ptr_res_pool_.end()) {
(void)ptr_res_pool_.erase(iter);
@ -78,10 +79,10 @@ class ResourceManager {
}
private:
std::unordered_map<Handle, BasePtr> ptr_res_pool_;
std::unordered_map<ConstHandle, BasePtr> ptr_res_pool_;
mindspore::HashMap<std::string, mindspore::Any> results_{};
std::shared_ptr<mindspore::compile::Backend> backend_ = nullptr;
std::shared_ptr<mindspore::MsContext> context_;
std::shared_ptr<mindspore::MsContext> context_ = nullptr;
std::string org_policy_;
bool auto_infer_;
};

View File

@ -113,7 +113,7 @@ TensorHandle MSNewTensorWithSrcType(ResMgrHandle res_mgr, void *data, const int6
return GetRawPtr(res_mgr, tensor);
}
void *MSTensorGetData(ResMgrHandle res_mgr, const TensorHandle tensor) {
void *MSTensorGetData(ResMgrHandle res_mgr, ConstTensorHandle tensor) {
if (res_mgr == nullptr || tensor == nullptr) {
MS_LOG(ERROR) << "Input Handle [res_mgr] or [tensor] is nullptr.";
return nullptr;
@ -126,7 +126,7 @@ void *MSTensorGetData(ResMgrHandle res_mgr, const TensorHandle tensor) {
return src_tensor->data_c();
}
STATUS MSTensorSetDataType(ResMgrHandle res_mgr, const TensorHandle tensor, TypeId type) {
STATUS MSTensorSetDataType(ResMgrHandle res_mgr, TensorHandle tensor, TypeId type) {
if (res_mgr == nullptr || tensor == nullptr) {
MS_LOG(ERROR) << "Input Handle [res_mgr] or [tensor] is nullptr.";
return RET_ERROR;
@ -140,7 +140,7 @@ STATUS MSTensorSetDataType(ResMgrHandle res_mgr, const TensorHandle tensor, Type
return RET_OK;
}
TypeId MSTensorGetDataType(ResMgrHandle res_mgr, const TensorHandle tensor, STATUS *error) {
TypeId MSTensorGetDataType(ResMgrHandle res_mgr, ConstTensorHandle tensor, STATUS *error) {
if (error == nullptr) {
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
return (enum TypeId)0;
@ -160,7 +160,7 @@ TypeId MSTensorGetDataType(ResMgrHandle res_mgr, const TensorHandle tensor, STAT
return (enum TypeId)(src_tensor->data_type_c());
}
size_t MSTensorGetDataSize(ResMgrHandle res_mgr, const TensorHandle tensor, STATUS *error) {
size_t MSTensorGetDataSize(ResMgrHandle res_mgr, ConstTensorHandle tensor, STATUS *error) {
if (error == nullptr) {
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
return 0;
@ -181,7 +181,7 @@ size_t MSTensorGetDataSize(ResMgrHandle res_mgr, const TensorHandle tensor, STAT
return size;
}
size_t MSTensorGetElementNum(ResMgrHandle res_mgr, const TensorHandle tensor, STATUS *error) {
size_t MSTensorGetElementNum(ResMgrHandle res_mgr, ConstTensorHandle tensor, STATUS *error) {
if (error == nullptr) {
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
return 0;
@ -202,7 +202,7 @@ size_t MSTensorGetElementNum(ResMgrHandle res_mgr, const TensorHandle tensor, ST
return ele_num;
}
size_t MSTensorGetDimension(ResMgrHandle res_mgr, const TensorHandle tensor, STATUS *error) {
size_t MSTensorGetDimension(ResMgrHandle res_mgr, ConstTensorHandle tensor, STATUS *error) {
if (error == nullptr) {
MS_LOG(ERROR) << "Input status flag [error] is nullptr.";
return 0;
@ -223,7 +223,7 @@ size_t MSTensorGetDimension(ResMgrHandle res_mgr, const TensorHandle tensor, STA
return dim;
}
STATUS MSTensorSetShape(ResMgrHandle res_mgr, const TensorHandle tensor, int64_t shape[], size_t dim) {
STATUS MSTensorSetShape(ResMgrHandle res_mgr, TensorHandle tensor, const int64_t shape[], size_t dim) {
if (res_mgr == nullptr || tensor == nullptr || shape == nullptr) {
MS_LOG(ERROR) << "Input Handle [res_mgr] or [tensor] or [shape] is nullptr.";
return RET_NULL_PTR;
@ -243,7 +243,7 @@ STATUS MSTensorSetShape(ResMgrHandle res_mgr, const TensorHandle tensor, int64_t
return RET_OK;
}
STATUS MSTensorGetShape(ResMgrHandle res_mgr, const TensorHandle tensor, int64_t shape[], size_t dim) {
STATUS MSTensorGetShape(ResMgrHandle res_mgr, ConstTensorHandle tensor, int64_t shape[], size_t dim) {
if (res_mgr == nullptr || tensor == nullptr || shape == nullptr) {
MS_LOG(ERROR) << "Input Handle [res_mgr] or [tensor] or [shape] is nullptr.";
return RET_NULL_PTR;