!48489 [MS]c_api support bool attr, tensor from file, save ir path

Merge pull request !48489 from liyejun/c_api_hl
This commit is contained in:
i-robot 2023-02-09 06:39:27 +00:00 committed by Gitee
commit 3ded5cf3e3
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
7 changed files with 123 additions and 25 deletions

View File

@ -145,8 +145,8 @@ MIND_C_API int64_t MSOpGetScalarAttrInt64(ResMgrHandle res_mgr, NodeHandle op, c
/// \param[in] value_num Size of the given array.
///
/// \return Error code indicates whether the function executed successfully.
MIND_C_API STATUS MSOpGetAttrsInt64(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, int64_t values[],
size_t value_num);
MIND_C_API STATUS MSOpGetAttrArrayInt64(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, int64_t values[],
size_t value_num);
/// \brief Create new Int64 attribute scalar value.
///
@ -164,6 +164,14 @@ MIND_C_API AttrHandle MSNewAttrInt64(ResMgrHandle res_mgr, const int64_t v);
/// \return Attribute value handle.
MIND_C_API AttrHandle MSNewAttrFloat32(ResMgrHandle res_mgr, const float v);
/// \brief Create new Bool attribute scalar value.
///
/// \param[in] res_mgr Resource Handle that manages the nodes of the funcGraph.
/// \param[in] v Given value.
///
/// \return Attribute value handle.
MIND_C_API AttrHandle MSNewAttrBool(ResMgrHandle res_mgr, const bool v);
/// \brief Create new attribute value with array.
///
/// \param[in] res_mgr Resource Handle that manages the nodes of the funcGraph.

View File

@ -72,6 +72,8 @@ MIND_C_API void MSSetDeviceId(uint32_t deviceId);
/// 3: Full level. Save all IR graphs.
MIND_C_API void MSSetGraphsSaveMode(int save_mode);
MIND_C_API void MSSetGraphsSavePath(const char *save_path);
/// \brief Set flag for auto shape and type infer
///
/// \param res_mgr Resource Handle that manages the nodes of the funcGraph.

View File

@ -42,6 +42,18 @@ extern "C" {
MIND_C_API TensorHandle MSNewTensor(ResMgrHandle res_mgr, void *data, TypeId type, const int64_t shape[],
size_t shape_size, size_t data_len);
/// \brief Create a tensor with path to a space-sperated txt file.
///
/// \param[in] res_mgr Resource manager that saves allocated instance resources.
/// \param[in] type [TypeId] Data type of the tensor.
/// \param[in] shape The shape arary of the tensor.
/// \param[in] shape_size The size of shape array, i.e., the rank of the tensor.
/// \param[in] path path to the file.
///
/// \return The pointer of the created tensor instance.
MIND_C_API TensorHandle MSNewTensorFromFile(ResMgrHandle res_mgr, TypeId type, const int64_t shape[], size_t shape_size,
const char *path);
/// \brief Create a tensor with input data buffer and given source data type.
///
/// \param[in] res_mgr Resource manager that saves allocated instance resources.

View File

@ -240,8 +240,8 @@ int64_t MSOpGetScalarAttrInt64(ResMgrHandle res_mgr, NodeHandle op, const char *
}
}
STATUS MSOpGetAttrsInt64(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, int64_t values[],
size_t value_num) {
STATUS MSOpGetAttrArrayInt64(ResMgrHandle res_mgr, NodeHandle op, const char *attr_name, int64_t values[],
size_t value_num) {
if (res_mgr == nullptr || op == nullptr || attr_name == nullptr) {
MS_LOG(ERROR) << "Input Handle [res_mgr] or [op] or [attr_name] is nullptr.";
return RET_NULL_PTR;
@ -288,6 +288,15 @@ AttrHandle MSNewAttrFloat32(ResMgrHandle res_mgr, const float v) {
return GetRawPtr(res_mgr, value);
}
AttrHandle MSNewAttrBool(ResMgrHandle res_mgr, const bool v) {
if (res_mgr == nullptr) {
MS_LOG(ERROR) << "Input Handle [res_mgr] is nullptr.";
return nullptr;
}
auto value = std::make_shared<BoolImmImpl>(v);
return GetRawPtr(res_mgr, value);
}
AttrHandle MSOpNewAttrs(ResMgrHandle res_mgr, void *value, size_t vec_size, TypeId data_type) {
if (res_mgr == nullptr || value == nullptr) {
MS_LOG(ERROR) << "Input Handle [res_mgr] or [value_vec] is nullptr.";

View File

@ -85,6 +85,13 @@ void MSSetGraphsSaveMode(int save_mode) {
return;
}
void MSSetGraphsSavePath(const char *save_path) {
MS_LOG(DEBUG) << "Set Graphs Save Path: " << save_path;
auto context = mindspore::MsContext::GetInstance();
context->set_param<std::string>(mindspore::MS_CTX_SAVE_GRAPHS_PATH, save_path);
return;
}
void MSSetInfer(ResMgrHandle res_mgr, bool infer) {
MS_LOG(DEBUG) << "Set Infer Graph: " << infer;
auto res_mgr_ptr = reinterpret_cast<ResourceManager *>(res_mgr);

View File

@ -21,6 +21,24 @@
#include "ir/tensor.h"
#include "ir/dtype.h"
template <typename T>
void GetDataByFile(std::vector<T> data, const char *path, size_t *elem_size) {
std::string fileName(path);
MS_LOG(INFO) << "Reading File: " << fileName << std::endl;
std::ifstream fin(fileName, std::ios::in);
if (!fin.is_open()) {
MS_LOG(ERROR) << "Open file failed, File path: %s " << fileName << std::endl;
return;
}
T t;
while (fin >> t) {
data.push_back(t);
}
fin.close();
*elem_size = data.size() * sizeof(T);
return;
}
TensorHandle MSNewTensor(ResMgrHandle res_mgr, void *data, TypeId type, const int64_t shape[], size_t shape_size,
size_t data_len) {
if (res_mgr == nullptr || data == nullptr || shape == nullptr) {
@ -38,6 +56,46 @@ TensorHandle MSNewTensor(ResMgrHandle res_mgr, void *data, TypeId type, const in
return GetRawPtr(res_mgr, tensor);
}
TensorHandle MSNewTensorFromFile(ResMgrHandle res_mgr, TypeId type, const int64_t shape[], size_t shape_size,
const char *path) {
if (res_mgr == nullptr || shape == nullptr) {
MS_LOG(ERROR) << "Input Handle [res_mgr] or [shape] is nullptr.";
return nullptr;
}
TensorPtr tensor = nullptr;
ShapeVector shape_vec(shape, shape + shape_size);
try {
size_t data_len;
switch (type) {
case TypeId::kNumberTypeInt32: {
std::vector<int32_t> data;
(void)GetDataByFile<int32_t>(data, path, &data_len);
tensor = std::make_shared<TensorImpl>(mindspore::TypeId(type), shape_vec, data.data(), data_len);
break;
}
case TypeId::kNumberTypeInt64: {
std::vector<int64_t> data;
(void)GetDataByFile<int64_t>(data, path, &data_len);
tensor = std::make_shared<TensorImpl>(mindspore::TypeId(type), shape_vec, data.data(), data_len);
break;
}
case TypeId::kNumberTypeFloat32: {
std::vector<float> data;
(void)GetDataByFile<float>(data, path, &data_len);
tensor = std::make_shared<TensorImpl>(mindspore::TypeId(type), shape_vec, data.data(), data_len);
break;
}
default:
MS_LOG(ERROR) << "Unrecognized datatype w/ TypeId: " << type << std::endl;
return nullptr;
}
} catch (const std::exception &e) {
MS_LOG(ERROR) << "New Tensor failed. Error info: " << e.what();
return nullptr;
}
return GetRawPtr(res_mgr, tensor);
}
TensorHandle MSNewTensorWithSrcType(ResMgrHandle res_mgr, void *data, const int64_t shape[], size_t shape_size,
TypeId tensor_type, TypeId src_type) {
if (res_mgr == nullptr || data == nullptr || shape == nullptr) {

View File

@ -26,7 +26,7 @@
#include "c_api/base/handle_types.h"
namespace mindspore {
class TestCApiAttr: public UT::Common {
class TestCApiAttr : public UT::Common {
public:
TestCApiAttr() = default;
};
@ -43,13 +43,13 @@ TEST_F(TestCApiAttr, test_attr) {
AttrHandle attr1 = MSNewAttrInt64(res_mgr, 1);
ASSERT_TRUE(attr1 != nullptr);
int64_t attr2_raw[] = {2,2};
AttrHandle attr2 = MSOpNewAttrs(res_mgr,attr2_raw,2,kNumberTypeInt64);
int64_t attr2_raw[] = {2, 2};
AttrHandle attr2 = MSOpNewAttrs(res_mgr, attr2_raw, 2, kNumberTypeInt64);
ASSERT_TRUE(attr2 != nullptr);
char name1[] = "attr1";
char name2[] = "attr2";
char *attr_names[] = {name1,name2};
AttrHandle attrs[] = {attr1,attr2};
char *attr_names[] = {name1, name2};
AttrHandle attrs[] = {attr1, attr2};
size_t attr_num = 2;
NodeHandle x = MSNewPlaceholder(res_mgr, fg, kNumberTypeInt32, NULL, 0);
@ -60,26 +60,28 @@ TEST_F(TestCApiAttr, test_attr) {
size_t input_num = 2;
NodeHandle op_add = MSNewOp(res_mgr, fg, "Add", input_nodes, input_num, attr_names, attrs, attr_num);
ASSERT_TRUE(op_add != nullptr);
int64_t attr1_retrived = MSOpGetScalarAttrInt64(res_mgr,op_add,"attr1",&ret);
int64_t attr1_retrived = MSOpGetScalarAttrInt64(res_mgr, op_add, "attr1", &ret);
ASSERT_EQ(ret, RET_OK);
ASSERT_EQ(attr1_retrived,1);
ASSERT_EQ(attr1_retrived, 1);
int64_t values[2];
ret = MSOpGetAttrsInt64(res_mgr,op_add,"attr2",values,2);
ASSERT_EQ(ret,RET_OK);
ASSERT_EQ(values[0],2);
ASSERT_EQ(values[1],2);
ret = MSOpSetScalarAttrInt64(res_mgr, op_add, "attr1", 2);
ASSERT_EQ(ret,RET_OK);
attr1_retrived = MSOpGetScalarAttrInt64(res_mgr,op_add,"attr1",&ret);
ret = MSOpGetAttrArrayInt64(res_mgr, op_add, "attr2", values, 2);
ASSERT_EQ(ret, RET_OK);
ASSERT_EQ(attr1_retrived,2);
ASSERT_EQ(values[0], 2);
ASSERT_EQ(values[1], 2);
ret = MSOpSetScalarAttrInt64(res_mgr, op_add, "attr1", 2);
ASSERT_EQ(ret, RET_OK);
attr1_retrived = MSOpGetScalarAttrInt64(res_mgr, op_add, "attr1", &ret);
ASSERT_EQ(ret, RET_OK);
ASSERT_EQ(attr1_retrived, 2);
values[0] = 1;
values[1] = 1;
ret = MSOpSetAttrArray(res_mgr,op_add,"attr2",values,2,kNumberTypeInt64);
ASSERT_EQ(ret,RET_OK);
ASSERT_EQ(values[0],1);
ASSERT_EQ(values[1],1);
ret = MSOpSetAttrArray(res_mgr, op_add, "attr2", values, 2, kNumberTypeInt64);
ASSERT_EQ(ret, RET_OK);
ret = MSOpGetAttrArrayInt64(res_mgr, op_add, "attr2", values, 2);
ASSERT_EQ(ret, RET_OK);
ASSERT_EQ(values[0], 1);
ASSERT_EQ(values[1], 1);
MSResourceManagerDestroy(res_mgr);
}
}
} // namespace mindspore