!36337 Add Graph and InMemoryGraphDataset for loading graph
Merge pull request !36337 from ms_yan/graph
This commit is contained in:
commit
6b42609608
|
@ -67,6 +67,7 @@
|
|||
"mindspore/mindspore/python/mindspore/dataset/engine/__init__.py" "redefined-builtin"
|
||||
"mindspore/mindspore/python/mindspore/dataset/engine/datasets.py" "redefined-builtin"
|
||||
"mindspore/mindspore/python/mindspore/dataset/engine/datasets.py" "broad-except"
|
||||
"mindspore/mindspore/python/mindspore/dataset/engine/graphdata.py" "super-init-not-called"
|
||||
"mindspore/mindspore/python/mindspore/dataset/transforms/py_transforms_util.py" "broad-except"
|
||||
|
||||
# Tests
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/stl_bind.h"
|
||||
|
||||
#include "minddata/dataset/api/python/pybind_conversion.h"
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
#include "minddata/dataset/engine/gnn/graph_data_client.h"
|
||||
#include "minddata/dataset/engine/gnn/graph_data_impl.h"
|
||||
|
@ -23,20 +24,46 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
const char kInvalidPath[] = "invalid_dataset_file_path";
|
||||
using FeatureType = std::int16_t;
|
||||
|
||||
PYBIND_REGISTER(
|
||||
Graph, 0, ([](const py::module *m) {
|
||||
(void)py::class_<gnn::GraphData, std::shared_ptr<gnn::GraphData>>(*m, "GraphDataClient")
|
||||
.def(py::init([](const std::string &dataset_file, int32_t num_workers, const std::string &working_mode,
|
||||
const std::string &hostname, int32_t port) {
|
||||
.def(py::init([](const std::string &data_format, const std::string &dataset_file, int32_t num_workers,
|
||||
const std::string &working_mode, const std::string &hostname, int32_t port) {
|
||||
std::shared_ptr<gnn::GraphData> out;
|
||||
if (working_mode == "local") {
|
||||
out = std::make_shared<gnn::GraphDataImpl>(dataset_file, num_workers);
|
||||
out = std::make_shared<gnn::GraphDataImpl>(data_format, dataset_file, num_workers);
|
||||
} else if (working_mode == "client") {
|
||||
out = std::make_shared<gnn::GraphDataClient>(dataset_file, hostname, port);
|
||||
}
|
||||
THROW_IF_ERROR(out->Init());
|
||||
return out;
|
||||
}))
|
||||
.def(py::init([](const std::string &data_format, int32_t num_nodes, const py::array &edges,
|
||||
const py::dict &node_feat, const py::dict &edge_feat, const py::dict &graph_feat,
|
||||
const py::array &node_type, const py::array &edge_type, int32_t num_workers,
|
||||
const std::string &working_mode, const std::string &hostname, int32_t port) {
|
||||
std::shared_ptr<gnn::GraphData> out;
|
||||
std::shared_ptr<Tensor> edge_tensor, node_type_tensor, edge_type_tensor;
|
||||
std::unordered_map<FeatureType, std::shared_ptr<Tensor>> node_feat_map, edge_feat_map, graph_feat_map;
|
||||
|
||||
THROW_IF_ERROR(convertNumpyData(edges, node_feat, edge_feat, graph_feat, node_type, edge_type, &edge_tensor,
|
||||
&node_feat_map, &edge_feat_map, &graph_feat_map, &node_type_tensor,
|
||||
&edge_type_tensor));
|
||||
|
||||
if (working_mode == "local") {
|
||||
out = std::make_shared<gnn::GraphDataImpl>(data_format, kInvalidPath, num_workers);
|
||||
THROW_IF_ERROR(out->Init(std::move(num_nodes), std::move(edge_tensor), std::move(node_feat_map),
|
||||
std::move(edge_feat_map), std::move(graph_feat_map), std::move(node_type_tensor),
|
||||
std::move(edge_type_tensor)));
|
||||
} else if (working_mode == "client") {
|
||||
out = std::make_shared<gnn::GraphDataClient>(kInvalidPath, hostname, port);
|
||||
THROW_IF_ERROR(out->Init());
|
||||
}
|
||||
return out;
|
||||
}))
|
||||
.def("get_all_nodes",
|
||||
[](gnn::GraphData &g, gnn::NodeType node_type) {
|
||||
std::shared_ptr<Tensor> out;
|
||||
|
@ -97,6 +124,12 @@ PYBIND_REGISTER(
|
|||
THROW_IF_ERROR(g.GetEdgeFeature(edge_list, feature_types, &out));
|
||||
return out.getRow();
|
||||
})
|
||||
.def("get_graph_feature",
|
||||
[](gnn::GraphData &g, std::vector<gnn::FeatureType> feature_types) {
|
||||
TensorRow out;
|
||||
THROW_IF_ERROR(g.GetGraphFeature(feature_types, &out));
|
||||
return out.getRow();
|
||||
})
|
||||
.def("graph_info",
|
||||
[](gnn::GraphData &g) {
|
||||
py::dict out;
|
||||
|
@ -114,14 +147,33 @@ PYBIND_REGISTER(
|
|||
.def("stop", [](gnn::GraphData &g) { THROW_IF_ERROR(g.Stop()); });
|
||||
|
||||
(void)py::class_<gnn::GraphDataServer, std::shared_ptr<gnn::GraphDataServer>>(*m, "GraphDataServer")
|
||||
.def(py::init([](const std::string &dataset_file, int32_t num_workers, const std::string &hostname, int32_t port,
|
||||
int32_t client_num, bool auto_shutdown) {
|
||||
.def(py::init([](const std::string &data_format, const std::string &dataset_file, int32_t num_workers,
|
||||
const std::string &hostname, int32_t port, int32_t client_num, bool auto_shutdown) {
|
||||
std::shared_ptr<gnn::GraphDataServer> out;
|
||||
out =
|
||||
std::make_shared<gnn::GraphDataServer>(dataset_file, num_workers, hostname, port, client_num, auto_shutdown);
|
||||
out = std::make_shared<gnn::GraphDataServer>(data_format, dataset_file, num_workers, hostname, port, client_num,
|
||||
auto_shutdown);
|
||||
THROW_IF_ERROR(out->Init());
|
||||
return out;
|
||||
}))
|
||||
.def(py::init([](const std::string &data_format, int32_t num_nodes, const py::array &edges,
|
||||
const py::dict &node_feat, const py::dict &edge_feat, const py::dict &graph_feat,
|
||||
const py::array &node_type, const py::array &edge_type, int32_t num_workers,
|
||||
const std::string &hostname, int32_t port, int32_t client_num, bool auto_shutdown) {
|
||||
std::shared_ptr<gnn::GraphDataServer> out;
|
||||
std::shared_ptr<Tensor> edge_tensor, node_type_tensor, edge_type_tensor;
|
||||
std::unordered_map<FeatureType, std::shared_ptr<Tensor>> node_feat_map, edge_feat_map, graph_feat_map;
|
||||
|
||||
THROW_IF_ERROR(convertNumpyData(edges, node_feat, edge_feat, graph_feat, node_type, edge_type, &edge_tensor,
|
||||
&node_feat_map, &edge_feat_map, &graph_feat_map, &node_type_tensor,
|
||||
&edge_type_tensor));
|
||||
|
||||
out = std::make_shared<gnn::GraphDataServer>(data_format, kInvalidPath, num_workers, hostname, port, client_num,
|
||||
auto_shutdown);
|
||||
THROW_IF_ERROR(out->Init(std::move(num_nodes), std::move(edge_tensor), std::move(node_feat_map),
|
||||
std::move(edge_feat_map), std::move(graph_feat_map), std::move(node_type_tensor),
|
||||
std::move(edge_type_tensor)));
|
||||
return out;
|
||||
}))
|
||||
.def("stop", [](gnn::GraphDataServer &g) { THROW_IF_ERROR(g.Stop()); })
|
||||
.def("is_stopped", [](gnn::GraphDataServer &g) { return g.IsStopped(); });
|
||||
}));
|
||||
|
|
|
@ -318,5 +318,41 @@ py::list typesToListOfType(std::vector<DataType> types) {
|
|||
}
|
||||
return type_list;
|
||||
}
|
||||
|
||||
Status toIntMapTensor(py::dict value, std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> *feature) {
|
||||
RETURN_UNEXPECTED_IF_NULL(feature);
|
||||
for (const auto &p : value) {
|
||||
// do some judge, as whether it is none
|
||||
std::shared_ptr<Tensor> feat_tensor = nullptr;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromNpArray(py::reinterpret_borrow<py::array>(p.second), &feat_tensor));
|
||||
(void)feature->insert({toInt(p.first), feat_tensor});
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status convertNumpyData(const py::array &edges, const py::dict &node_feat, const py::dict &edge_feat,
|
||||
const py::dict &graph_feat, const py::array &node_type, const py::array &edge_type,
|
||||
std::shared_ptr<Tensor> *edge_tensor,
|
||||
std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> *node_feat_map,
|
||||
std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> *edge_feat_map,
|
||||
std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> *graph_feat_map,
|
||||
std::shared_ptr<Tensor> *node_type_tensor, std::shared_ptr<Tensor> *edge_type_tensor) {
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromNpArray(edges, edge_tensor));
|
||||
if (!node_feat.empty()) {
|
||||
RETURN_IF_NOT_OK(toIntMapTensor(node_feat, node_feat_map));
|
||||
}
|
||||
|
||||
if (!edge_feat.empty()) {
|
||||
RETURN_IF_NOT_OK(toIntMapTensor(edge_feat, edge_feat_map));
|
||||
}
|
||||
|
||||
if (!graph_feat.empty()) {
|
||||
RETURN_IF_NOT_OK(toIntMapTensor(graph_feat, graph_feat_map));
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromNpArray(node_type, node_type_tensor));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromNpArray(edge_type, edge_type_tensor));
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -92,6 +92,16 @@ py::list shapesToListOfShape(std::vector<TensorShape> shapes);
|
|||
|
||||
py::list typesToListOfType(std::vector<DataType> types);
|
||||
|
||||
Status toIntMapTensor(py::dict value, std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> *feature);
|
||||
|
||||
// abstract similar logic in gnn bindings part into one function
|
||||
Status convertNumpyData(const py::array &edges, const py::dict &node_feat, const py::dict &edge_feat,
|
||||
const py::dict &graph_feat, const py::array &node_type, const py::array &edge_type,
|
||||
std::shared_ptr<Tensor> *edge_tensor,
|
||||
std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> *node_feat_map,
|
||||
std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> *edge_feat_map,
|
||||
std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> *graph_feat_map,
|
||||
std::shared_ptr<Tensor> *node_type_tensor, std::shared_ptr<Tensor> *edge_type_tensor);
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_API_PYTHON_PYBIND_CONVERSION_H_
|
||||
|
|
|
@ -5,6 +5,7 @@ set(DATASET_ENGINE_GNN_SRC_FILES
|
|||
graph_data_client.cc
|
||||
graph_data_server.cc
|
||||
graph_loader.cc
|
||||
graph_loader_array.cc
|
||||
graph_feature_parser.cc
|
||||
local_node.cc
|
||||
local_edge.cc
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -35,11 +35,10 @@ class Edge {
|
|||
// @param EdgeIdType id - edge id
|
||||
// @param EdgeType type - edge type
|
||||
// @param WeightType weight - edge weight
|
||||
// @param std::shared_ptr<Node> src_node - source node
|
||||
// @param std::shared_ptr<Node> dst_node - destination node
|
||||
Edge(EdgeIdType id, EdgeType type, WeightType weight, const std::shared_ptr<Node> &src_node,
|
||||
const std::shared_ptr<Node> &dst_node)
|
||||
: id_(id), type_(type), weight_(weight), src_node_(src_node), dst_node_(dst_node) {}
|
||||
// @param NodeIdType src_id - source node id
|
||||
// @param NodeIdType dst_id - destination node id
|
||||
Edge(EdgeIdType id, EdgeType type, WeightType weight, NodeIdType src_id, NodeIdType dst_id)
|
||||
: id_(id), type_(type), weight_(weight), src_id_(src_id), dst_id_(dst_id) {}
|
||||
|
||||
virtual ~Edge() = default;
|
||||
|
||||
|
@ -59,18 +58,20 @@ class Edge {
|
|||
virtual Status GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) = 0;
|
||||
|
||||
// Get nodes on the edge
|
||||
// @param std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> *out_node - Source and destination nodes returned
|
||||
Status GetNode(std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> *out_node) {
|
||||
RETURN_UNEXPECTED_IF_NULL(out_node);
|
||||
*out_node = std::make_pair(src_node_, dst_node_);
|
||||
// @param NodeIdType *src_id - Source node id returned
|
||||
// @param NodeIdType *dst_id - Destination node id returned
|
||||
Status GetNode(NodeIdType *src_id, NodeIdType *dst_id) {
|
||||
*src_id = src_id_;
|
||||
*dst_id = dst_id_;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Set node to edge
|
||||
// @param const std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> &in_node -
|
||||
Status SetNode(const std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> &in_node) {
|
||||
src_node_ = in_node.first;
|
||||
dst_node_ = in_node.second;
|
||||
// @param NodeIdType src_id - Source node id
|
||||
// @param NodeIdType dst_id - Destination node id
|
||||
Status SetNode(NodeIdType src_id, NodeIdType dst_id) {
|
||||
src_id_ = src_id;
|
||||
dst_id_ = dst_id;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -83,8 +84,8 @@ class Edge {
|
|||
EdgeIdType id_;
|
||||
EdgeType type_;
|
||||
WeightType weight_;
|
||||
std::shared_ptr<Node> src_node_;
|
||||
std::shared_ptr<Node> dst_node_;
|
||||
NodeIdType src_id_;
|
||||
NodeIdType dst_id_;
|
||||
};
|
||||
} // namespace gnn
|
||||
} // namespace dataset
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -36,6 +36,7 @@ message GnnClientRegisterResponsePb {
|
|||
int64 shared_memory_size = 4;
|
||||
repeated GnnFeatureInfoPb default_node_feature = 5;
|
||||
repeated GnnFeatureInfoPb default_edge_feature = 6;
|
||||
repeated GnnFeatureInfoPb graph_feature = 7;
|
||||
}
|
||||
|
||||
message GnnClientUnRegisterRequestPb {
|
||||
|
@ -102,6 +103,7 @@ message GnnMetaInfoResponsePb {
|
|||
repeated GnnNodeEdgeInfoPb edge_info = 3;
|
||||
repeated int32 node_feature_type = 4;
|
||||
repeated int32 edge_feature_type = 5;
|
||||
repeated int32 graph_feature_type = 6;
|
||||
}
|
||||
|
||||
service GnnGraphData {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -20,6 +20,7 @@
|
|||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
|
@ -40,6 +41,7 @@ struct MetaInfo {
|
|||
std::map<EdgeType, EdgeIdType> edge_num;
|
||||
std::vector<FeatureType> node_feature_type;
|
||||
std::vector<FeatureType> edge_feature_type;
|
||||
std::vector<FeatureType> graph_feature_type;
|
||||
};
|
||||
|
||||
class GraphData {
|
||||
|
@ -119,7 +121,7 @@ class GraphData {
|
|||
// Get the feature of a node
|
||||
// @param std::shared_ptr<Tensor> nodes - List of nodes
|
||||
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
|
||||
// does not exist.
|
||||
// does not exist.
|
||||
// @param TensorRow *out - Returned features
|
||||
// @return Status The status code returned
|
||||
virtual Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
|
||||
|
@ -128,17 +130,32 @@ class GraphData {
|
|||
// Get the feature of a edge
|
||||
// @param std::shared_ptr<Tensor> edges - List of edges
|
||||
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
|
||||
// does not exist.
|
||||
// does not exist.
|
||||
// @param Tensor *out - Returned features
|
||||
// @return Status The status code returned
|
||||
virtual Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types,
|
||||
TensorRow *out) = 0;
|
||||
|
||||
// Get the feature in graph level
|
||||
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
|
||||
// does not exist.
|
||||
// @param Tensor *out - Returned features
|
||||
// @return Status The status code returned
|
||||
virtual Status GetGraphFeature(const std::vector<FeatureType> &feature_types, TensorRow *out) = 0;
|
||||
|
||||
// Return meta information to python layer
|
||||
virtual Status GraphInfo(py::dict *out) = 0;
|
||||
|
||||
virtual Status Init() = 0;
|
||||
|
||||
virtual Status Init(int32_t num_nodes, const std::shared_ptr<Tensor> &edge,
|
||||
const std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> &node_feat,
|
||||
const std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> &edge_feat,
|
||||
const std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> &graph_feat,
|
||||
const std::shared_ptr<Tensor> &node_type, const std::shared_ptr<Tensor> &edge_type) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
virtual Status Stop() = 0;
|
||||
};
|
||||
} // namespace gnn
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -50,7 +50,7 @@ GraphDataClient::~GraphDataClient() { (void)Stop(); }
|
|||
|
||||
Status GraphDataClient::Init() {
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
RETURN_STATUS_UNEXPECTED("Graph data client is not supported in Windows OS");
|
||||
RETURN_STATUS_UNEXPECTED("Graph data client is not supported in Windows OS.");
|
||||
#else
|
||||
if (!registered_) {
|
||||
std::string server_address;
|
||||
|
@ -81,6 +81,8 @@ Status GraphDataClient::Stop() {
|
|||
if (registered_) {
|
||||
RETURN_IF_NOT_OK(UnRegisterToServer());
|
||||
}
|
||||
#else
|
||||
RETURN_STATUS_UNEXPECTED("Graph data client is not supported in Windows OS.");
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -93,6 +95,8 @@ Status GraphDataClient::GetAllNodes(NodeType node_type, std::shared_ptr<Tensor>
|
|||
request.set_op_name(GET_ALL_NODES);
|
||||
request.add_type(static_cast<google::protobuf::int32>(node_type));
|
||||
RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out));
|
||||
#else
|
||||
RETURN_STATUS_UNEXPECTED("This operation is not supported in Windows OS.");
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -105,6 +109,8 @@ Status GraphDataClient::GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor>
|
|||
request.set_op_name(GET_ALL_EDGES);
|
||||
request.add_type(static_cast<google::protobuf::int32>(edge_type));
|
||||
RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out));
|
||||
#else
|
||||
RETURN_STATUS_UNEXPECTED("This operation is not supported in Windows OS.");
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -119,6 +125,8 @@ Status GraphDataClient::GetNodesFromEdges(const std::vector<EdgeIdType> &edge_li
|
|||
request.add_id(static_cast<google::protobuf::int32>(edge_id));
|
||||
}
|
||||
RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out));
|
||||
#else
|
||||
RETURN_STATUS_UNEXPECTED("This operation is not supported in Windows OS.");
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -139,6 +147,8 @@ Status GraphDataClient::GetEdgesFromNodes(const std::vector<std::pair<NodeIdType
|
|||
}
|
||||
|
||||
RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out));
|
||||
#else
|
||||
RETURN_STATUS_UNEXPECTED("This operation is not supported in Windows OS.");
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -156,6 +166,8 @@ Status GraphDataClient::GetAllNeighbors(const std::vector<NodeIdType> &node_list
|
|||
request.add_type(static_cast<google::protobuf::int32>(neighbor_type));
|
||||
request.set_format(static_cast<google::protobuf::int32>(format));
|
||||
RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out));
|
||||
#else
|
||||
RETURN_STATUS_UNEXPECTED("This operation is not supported in Windows OS.");
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -180,6 +192,8 @@ Status GraphDataClient::GetSampledNeighbors(const std::vector<NodeIdType> &node_
|
|||
}
|
||||
request.set_strategy(static_cast<google::protobuf::int32>(strategy));
|
||||
RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out));
|
||||
#else
|
||||
RETURN_STATUS_UNEXPECTED("This operation is not supported in Windows OS.");
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -197,6 +211,8 @@ Status GraphDataClient::GetNegSampledNeighbors(const std::vector<NodeIdType> &no
|
|||
request.add_number(static_cast<google::protobuf::int32>(samples_num));
|
||||
request.add_type(static_cast<google::protobuf::int32>(neg_neighbor_type));
|
||||
RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out));
|
||||
#else
|
||||
RETURN_STATUS_UNEXPECTED("This operation is not supported in Windows OS.");
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -221,6 +237,8 @@ Status GraphDataClient::GraphDataClient::RandomWalk(const std::vector<NodeIdType
|
|||
walk_param->set_q(step_away_param);
|
||||
walk_param->set_default_id(static_cast<google::protobuf::int32>(default_node));
|
||||
RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out));
|
||||
#else
|
||||
RETURN_STATUS_UNEXPECTED("This operation is not supported in Windows OS.");
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -257,6 +275,8 @@ Status GraphDataClient::GetNodeFeature(const std::shared_ptr<Tensor> &nodes,
|
|||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("RPC failed: The number of returned tensor is abnormal");
|
||||
}
|
||||
#else
|
||||
RETURN_STATUS_UNEXPECTED("This operation is not supported in Windows OS.");
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -293,6 +313,23 @@ Status GraphDataClient::GetEdgeFeature(const std::shared_ptr<Tensor> &edges,
|
|||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("RPC failed: The number of returned tensor is abnormal");
|
||||
}
|
||||
#else
|
||||
RETURN_STATUS_UNEXPECTED("This operation is not supported in Windows OS.");
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphDataClient::GetGraphFeature(const std::vector<FeatureType> &feature_types, TensorRow *out) {
|
||||
RETURN_UNEXPECTED_IF_NULL(out);
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Input feature_types is empty.");
|
||||
for (auto i = 0; i < feature_types.size(); i++) {
|
||||
std::shared_ptr<Tensor> fea_tensor;
|
||||
RETURN_IF_NOT_OK(GetStoredGraphFeature(feature_types[i], &fea_tensor));
|
||||
out->emplace_back(std::move(fea_tensor));
|
||||
}
|
||||
#else
|
||||
RETURN_STATUS_UNEXPECTED("This operation is not supported in Windows OS.");
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -343,17 +380,23 @@ Status GraphDataClient::GraphInfo(py::dict *out) {
|
|||
for (const auto &feature_type : response.edge_feature_type()) {
|
||||
meta_info.edge_feature_type.emplace_back(static_cast<FeatureType>(feature_type));
|
||||
}
|
||||
for (const auto &feature_type : response.graph_feature_type()) {
|
||||
meta_info.graph_feature_type.emplace_back(static_cast<FeatureType>(feature_type));
|
||||
}
|
||||
(*out)["node_type"] = py::cast(meta_info.node_type);
|
||||
(*out)["edge_type"] = py::cast(meta_info.edge_type);
|
||||
(*out)["node_num"] = py::cast(meta_info.node_num);
|
||||
(*out)["edge_num"] = py::cast(meta_info.edge_num);
|
||||
(*out)["node_feature_type"] = py::cast(meta_info.node_feature_type);
|
||||
(*out)["edge_feature_type"] = py::cast(meta_info.edge_feature_type);
|
||||
(*out)["graph_feature_type"] = py::cast(meta_info.graph_feature_type);
|
||||
}
|
||||
} else {
|
||||
auto error_code = status.error_code();
|
||||
RETURN_STATUS_UNEXPECTED(status.error_message() + ". GRPC Code " + std::to_string(error_code));
|
||||
}
|
||||
#else
|
||||
RETURN_STATUS_UNEXPECTED("This operation is not supported in Windows OS.");
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -514,6 +557,16 @@ Status GraphDataClient::GetEdgeDefaultFeature(FeatureType feature_type, std::sha
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphDataClient::GetStoredGraphFeature(FeatureType feature_type, std::shared_ptr<Tensor> *out_feature) {
|
||||
auto itr = graph_feature_map_.find(feature_type);
|
||||
if (itr == graph_feature_map_.end()) {
|
||||
std::string err_msg = "Invalid feature type:" + std::to_string(feature_type);
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
*out_feature = itr->second;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphDataClient::RegisterToServer() {
|
||||
RETURN_IF_NOT_OK(CheckPid());
|
||||
void *tag;
|
||||
|
@ -547,16 +600,21 @@ Status GraphDataClient::RegisterToServer() {
|
|||
shared_memory_key_ = static_cast<key_t>(response.shared_memory_key());
|
||||
shared_memory_size_ = response.shared_memory_size();
|
||||
MS_LOG(INFO) << "Register success, recv data_schema:" << response.data_schema();
|
||||
for (auto feature_info : response.default_node_feature()) {
|
||||
for (const auto &feature_info : response.default_node_feature()) {
|
||||
std::shared_ptr<Tensor> tensor;
|
||||
RETURN_IF_NOT_OK(PbToTensor(&feature_info.feature(), &tensor));
|
||||
default_node_feature_map_[feature_info.type()] = tensor;
|
||||
}
|
||||
for (auto feature_info : response.default_edge_feature()) {
|
||||
for (const auto &feature_info : response.default_edge_feature()) {
|
||||
std::shared_ptr<Tensor> tensor;
|
||||
RETURN_IF_NOT_OK(PbToTensor(&feature_info.feature(), &tensor));
|
||||
default_edge_feature_map_[feature_info.type()] = tensor;
|
||||
}
|
||||
for (const auto &feature_info : response.graph_feature()) {
|
||||
std::shared_ptr<Tensor> tensor;
|
||||
RETURN_IF_NOT_OK(PbToTensor(&feature_info.feature(), &tensor));
|
||||
graph_feature_map_[feature_info.type()] = tensor;
|
||||
}
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED(response.error_msg());
|
||||
}
|
||||
|
@ -611,8 +669,11 @@ Status GraphDataClient::InitFeatureParser() {
|
|||
graph_shared_memory_ = std::make_unique<GraphSharedMemory>(shared_memory_size_, shared_memory_key_);
|
||||
RETURN_IF_NOT_OK(graph_shared_memory_->GetSharedMemory());
|
||||
// build feature parser
|
||||
graph_feature_parser_ = std::make_unique<GraphFeatureParser>(ShardColumn(data_schema_));
|
||||
|
||||
if (data_schema_ != nullptr) {
|
||||
graph_feature_parser_ = std::make_unique<GraphFeatureParser>(ShardColumn(data_schema_));
|
||||
} else {
|
||||
MS_LOG(INFO) << "data_schema is no used, as input data is array for creating graph.";
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -125,7 +125,7 @@ class GraphDataClient : public GraphData {
|
|||
// Get the feature of a node
|
||||
// @param std::shared_ptr<Tensor> nodes - List of nodes
|
||||
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
|
||||
// does not exist.
|
||||
// does not exist.
|
||||
// @param TensorRow *out - Returned features
|
||||
// @return Status The status code returned
|
||||
Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
|
||||
|
@ -134,12 +134,19 @@ class GraphDataClient : public GraphData {
|
|||
// Get the feature of a edge
|
||||
// @param std::shared_ptr<Tensor> edges - List of edges
|
||||
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
|
||||
// does not exist.
|
||||
// does not exist.
|
||||
// @param Tensor *out - Returned features
|
||||
// @return Status The status code returned
|
||||
Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types,
|
||||
TensorRow *out) override;
|
||||
|
||||
// Get the feature in graph level
|
||||
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
|
||||
// does not exist.
|
||||
// @param Tensor *out - Returned features
|
||||
// @return Status The status code returned
|
||||
Status GetGraphFeature(const std::vector<FeatureType> &feature_types, TensorRow *out) override;
|
||||
|
||||
// Return meta information to python layer
|
||||
Status GraphInfo(py::dict *out) override;
|
||||
|
||||
|
@ -155,6 +162,8 @@ class GraphDataClient : public GraphData {
|
|||
|
||||
Status GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Tensor> *out_feature);
|
||||
|
||||
Status GetStoredGraphFeature(FeatureType feature_type, std::shared_ptr<Tensor> *out_feature);
|
||||
|
||||
Status GetGraphData(const GnnGraphDataRequestPb &request, GnnGraphDataResponsePb *response);
|
||||
|
||||
Status GetGraphDataTensor(const GnnGraphDataRequestPb &request, GnnGraphDataResponsePb *response,
|
||||
|
@ -186,6 +195,7 @@ class GraphDataClient : public GraphData {
|
|||
std::unique_ptr<GraphSharedMemory> graph_shared_memory_;
|
||||
std::unordered_map<FeatureType, std::shared_ptr<Tensor>> default_node_feature_map_;
|
||||
std::unordered_map<FeatureType, std::shared_ptr<Tensor>> default_edge_feature_map_;
|
||||
std::unordered_map<FeatureType, std::shared_ptr<Tensor>> graph_feature_map_;
|
||||
#endif
|
||||
bool registered_;
|
||||
};
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -23,13 +23,16 @@
|
|||
|
||||
#include "minddata/dataset/core/tensor_shape.h"
|
||||
#include "minddata/dataset/engine/gnn/graph_loader.h"
|
||||
#include "minddata/dataset/engine/gnn/graph_loader_array.h"
|
||||
#include "minddata/dataset/util/random.h"
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace gnn {
|
||||
|
||||
GraphDataImpl::GraphDataImpl(const std::string &dataset_file, int32_t num_workers, bool server_mode)
|
||||
: dataset_file_(dataset_file),
|
||||
GraphDataImpl::GraphDataImpl(const std::string &data_format, const std::string &dataset_file, int32_t num_workers,
|
||||
bool server_mode)
|
||||
: data_format_(data_format),
|
||||
dataset_file_(dataset_file),
|
||||
num_workers_(num_workers),
|
||||
rnd_(GetRandomDevice()),
|
||||
random_walk_(this),
|
||||
|
@ -123,9 +126,10 @@ Status GraphDataImpl::GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list
|
|||
std::string err_msg = "Invalid edge id:" + std::to_string(edge_id);
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
} else {
|
||||
std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> nodes;
|
||||
RETURN_IF_NOT_OK(itr->second->GetNode(&nodes));
|
||||
node_list.push_back({nodes.first->id(), nodes.second->id()});
|
||||
NodeIdType src_id, dst_id;
|
||||
RETURN_UNEXPECTED_IF_NULL(itr->second);
|
||||
RETURN_IF_NOT_OK(itr->second->GetNode(&src_id, &dst_id));
|
||||
node_list.push_back({src_id, dst_id});
|
||||
}
|
||||
}
|
||||
RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(node_list, DataType(DataType::DE_INT32), out));
|
||||
|
@ -276,7 +280,7 @@ Status GraphDataImpl::GetSampledNeighbors(const std::vector<NodeIdType> &node_li
|
|||
std::shared_ptr<Node> node;
|
||||
RETURN_IF_NOT_OK(GetNodeByNodeId(node_id, &node));
|
||||
std::vector<NodeIdType> out;
|
||||
RETURN_IF_NOT_OK(node->GetSampledNeighbors(neighbor_types[i], neighbor_nums[i], strategy, &out));
|
||||
RETURN_IF_NOT_OK(node->GetSampledNeighbors(neighbor_types[i], neighbor_nums[i], strategy, &out, &rnd_));
|
||||
neighbors.insert(neighbors.end(), out.begin(), out.end());
|
||||
}
|
||||
}
|
||||
|
@ -565,8 +569,47 @@ Status GraphDataImpl::GetEdgeFeatureSharedMemory(const std::shared_ptr<Tensor> &
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphDataImpl::GetGraphFeature(const std::vector<FeatureType> &feature_types, TensorRow *out) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Input feature_types is empty.");
|
||||
RETURN_UNEXPECTED_IF_NULL(out);
|
||||
TensorRow tensors;
|
||||
for (const auto &type : feature_types) {
|
||||
std::shared_ptr<Feature> feature;
|
||||
auto itr = graph_feature_map_.find(type);
|
||||
if (itr == graph_feature_map_.end()) {
|
||||
std::string err_msg = "Invalid feature type:" + std::to_string(type);
|
||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||
}
|
||||
feature = itr->second;
|
||||
tensors.push_back(feature->Value());
|
||||
}
|
||||
*out = std::move(tensors);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphDataImpl::Init() {
|
||||
RETURN_IF_NOT_OK(LoadNodeAndEdge());
|
||||
if (data_format_ != "mindrecord") {
|
||||
RETURN_STATUS_UNEXPECTED("Data Format should be `mindrecord` as dataset file is provided.");
|
||||
}
|
||||
GraphLoader gl(this, dataset_file_, num_workers_, server_mode_);
|
||||
|
||||
// ask graph_loader to load everything into memory
|
||||
RETURN_IF_NOT_OK(gl.InitAndLoad());
|
||||
RETURN_IF_NOT_OK(gl.GetNodesAndEdges());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphDataImpl::Init(int32_t num_nodes, const std::shared_ptr<Tensor> &edge,
|
||||
const std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> &node_feat,
|
||||
const std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> &edge_feat,
|
||||
const std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> &graph_feat,
|
||||
const std::shared_ptr<Tensor> &node_type, const std::shared_ptr<Tensor> &edge_type) {
|
||||
MS_LOG(INFO) << "Create graph with loading numpy array data.";
|
||||
GraphLoaderFromArray gl(this, num_nodes, edge, node_feat, edge_feat, graph_feat, node_type, edge_type, num_workers_,
|
||||
server_mode_);
|
||||
RETURN_IF_NOT_OK(gl.InitAndLoad());
|
||||
RETURN_IF_NOT_OK(gl.GetNodesAndEdges());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -607,6 +650,10 @@ Status GraphDataImpl::GetMetaInfo(MetaInfo *meta_info) {
|
|||
std::sort(meta_info->edge_feature_type.begin(), meta_info->edge_feature_type.end());
|
||||
auto unique_edge = std::unique(meta_info->edge_feature_type.begin(), meta_info->edge_feature_type.end());
|
||||
meta_info->edge_feature_type.erase(unique_edge, meta_info->edge_feature_type.end());
|
||||
|
||||
for (const auto &graph_feature : graph_feature_map_) {
|
||||
meta_info->graph_feature_type.emplace_back(graph_feature.first);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -621,19 +668,11 @@ Status GraphDataImpl::GraphInfo(py::dict *out) {
|
|||
(*out)["edge_num"] = py::cast(meta_info.edge_num);
|
||||
(*out)["node_feature_type"] = py::cast(meta_info.node_feature_type);
|
||||
(*out)["edge_feature_type"] = py::cast(meta_info.edge_feature_type);
|
||||
(*out)["graph_feature_type"] = py::cast(meta_info.graph_feature_type);
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
Status GraphDataImpl::LoadNodeAndEdge() {
|
||||
GraphLoader gl(this, dataset_file_, num_workers_, server_mode_);
|
||||
// ask graph_loader to load everything into memory
|
||||
RETURN_IF_NOT_OK(gl.InitAndLoad());
|
||||
// get all maps
|
||||
RETURN_IF_NOT_OK(gl.GetNodesAndEdges());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphDataImpl::GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node) {
|
||||
RETURN_UNEXPECTED_IF_NULL(node);
|
||||
auto itr = node_id_map_.find(id);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -42,9 +42,11 @@ using StochasticIndex = std::pair<std::vector<int32_t>, std::vector<float>>;
|
|||
class GraphDataImpl : public GraphData {
|
||||
public:
|
||||
// Constructor
|
||||
// @param std::string data_format - support mindrecord or array
|
||||
// @param std::string dataset_file -
|
||||
// @param int32_t num_workers - number of parallel threads
|
||||
GraphDataImpl(const std::string &dataset_file, int32_t num_workers, bool server_mode = false);
|
||||
GraphDataImpl(const std::string &data_format, const std::string &dataset_file, int32_t num_workers,
|
||||
bool server_mode = false);
|
||||
|
||||
~GraphDataImpl() override;
|
||||
|
||||
|
@ -119,7 +121,7 @@ class GraphDataImpl : public GraphData {
|
|||
// Get the feature of a node
|
||||
// @param std::shared_ptr<Tensor> nodes - List of nodes
|
||||
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
|
||||
// does not exist.
|
||||
// does not exist.
|
||||
// @param TensorRow *out - Returned features
|
||||
// @return Status The status code returned
|
||||
Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
|
||||
|
@ -131,7 +133,7 @@ class GraphDataImpl : public GraphData {
|
|||
// Get the feature of a edge
|
||||
// @param std::shared_ptr<Tensor> edges - List of edges
|
||||
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
|
||||
// does not exist.
|
||||
// does not exist.
|
||||
// @param Tensor *out - Returned features
|
||||
// @return Status The status code returned
|
||||
Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types,
|
||||
|
@ -140,6 +142,13 @@ class GraphDataImpl : public GraphData {
|
|||
Status GetEdgeFeatureSharedMemory(const std::shared_ptr<Tensor> &edges, FeatureType type,
|
||||
std::shared_ptr<Tensor> *out);
|
||||
|
||||
// Get the feature in graph level
|
||||
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
|
||||
// does not exist.
|
||||
// @param Tensor *out - Returned features
|
||||
// @return Status The status code returned
|
||||
Status GetGraphFeature(const std::vector<FeatureType> &feature_types, TensorRow *out) override;
|
||||
|
||||
// Get meta information of graph
|
||||
// @param MetaInfo *meta_info - Returned meta information
|
||||
// @return Status The status code returned
|
||||
|
@ -158,8 +167,18 @@ class GraphDataImpl : public GraphData {
|
|||
return &default_edge_feature_map_;
|
||||
}
|
||||
|
||||
const std::unordered_map<FeatureType, std::shared_ptr<Feature>> *GetAllGraphFeatures() const {
|
||||
return &graph_feature_map_;
|
||||
}
|
||||
|
||||
Status Init() override;
|
||||
|
||||
Status Init(int32_t num_nodes, const std::shared_ptr<Tensor> &edge,
|
||||
const std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> &node_feat,
|
||||
const std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> &edge_feat,
|
||||
const std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> &graph_feat,
|
||||
const std::shared_ptr<Tensor> &node_type, const std::shared_ptr<Tensor> &edge_type) override;
|
||||
|
||||
Status Stop() override { return Status::OK(); }
|
||||
|
||||
std::string GetDataSchema() { return data_schema_.dump(); }
|
||||
|
@ -172,6 +191,7 @@ class GraphDataImpl : public GraphData {
|
|||
|
||||
private:
|
||||
friend class GraphLoader;
|
||||
friend class GraphLoaderFromArray;
|
||||
class RandomWalkBase {
|
||||
public:
|
||||
explicit RandomWalkBase(GraphDataImpl *graph);
|
||||
|
@ -211,10 +231,6 @@ class GraphDataImpl : public GraphData {
|
|||
int32_t num_workers_; // The number of worker threads. Default is 1
|
||||
};
|
||||
|
||||
// Load graph data from mindrecord file
|
||||
// @return Status The status code returned
|
||||
Status LoadNodeAndEdge();
|
||||
|
||||
// Create Tensor By Vector
|
||||
// @param std::vector<std::vector<T>> &data -
|
||||
// @param DataType type -
|
||||
|
@ -269,6 +285,7 @@ class GraphDataImpl : public GraphData {
|
|||
|
||||
Status CheckNeighborType(NodeType neighbor_type);
|
||||
|
||||
std::string data_format_;
|
||||
std::string dataset_file_;
|
||||
int32_t num_workers_; // The number of worker threads
|
||||
std::mt19937 rnd_;
|
||||
|
@ -287,6 +304,7 @@ class GraphDataImpl : public GraphData {
|
|||
std::unordered_map<NodeType, std::unordered_set<FeatureType>> node_feature_map_;
|
||||
std::unordered_map<EdgeType, std::unordered_set<FeatureType>> edge_feature_map_;
|
||||
|
||||
std::unordered_map<FeatureType, std::shared_ptr<Feature>> graph_feature_map_;
|
||||
std::unordered_map<FeatureType, std::shared_ptr<Feature>> default_node_feature_map_;
|
||||
std::unordered_map<FeatureType, std::shared_ptr<Feature>> default_edge_feature_map_;
|
||||
};
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -26,16 +26,17 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
namespace gnn {
|
||||
|
||||
GraphDataServer::GraphDataServer(const std::string &dataset_file, int32_t num_workers, const std::string &hostname,
|
||||
int32_t port, int32_t client_num, bool auto_shutdown)
|
||||
: dataset_file_(dataset_file),
|
||||
GraphDataServer::GraphDataServer(const std::string &data_format, const std::string &dataset_file, int32_t num_workers,
|
||||
const std::string &hostname, int32_t port, int32_t client_num, bool auto_shutdown)
|
||||
: data_format_(data_format),
|
||||
dataset_file_(dataset_file),
|
||||
num_workers_(num_workers),
|
||||
client_num_(client_num),
|
||||
max_connected_client_num_(0),
|
||||
auto_shutdown_(auto_shutdown),
|
||||
state_(kGdsUninit) {
|
||||
tg_ = std::make_unique<TaskGroup>();
|
||||
graph_data_impl_ = std::make_unique<GraphDataImpl>(dataset_file, num_workers, true);
|
||||
graph_data_impl_ = std::make_unique<GraphDataImpl>(data_format, dataset_file, num_workers, true);
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
service_impl_ = std::make_unique<GraphDataServiceImpl>(this, graph_data_impl_.get());
|
||||
async_server_ = std::make_unique<GraphDataGrpcServer>(hostname, port, service_impl_.get());
|
||||
|
@ -61,6 +62,31 @@ Status GraphDataServer::Init() {
|
|||
#endif
|
||||
}
|
||||
|
||||
Status GraphDataServer::Init(int32_t num_nodes, const std::shared_ptr<Tensor> &edge,
|
||||
const std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> &node_feat,
|
||||
const std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> &edge_feat,
|
||||
const std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> &graph_feat,
|
||||
const std::shared_ptr<Tensor> &node_type, const std::shared_ptr<Tensor> &edge_type) {
|
||||
#if defined(_WIN32) || defined(_WIN64)
|
||||
RETURN_STATUS_UNEXPECTED("Graph data server is not supported in Windows OS");
|
||||
#else
|
||||
set_state(kGdsInitializing);
|
||||
RETURN_IF_NOT_OK(async_server_->Run());
|
||||
RETURN_IF_NOT_OK(tg_->CreateAsyncTask(
|
||||
"init graph data impl", std::bind(&GraphDataServer::InitNumpyGraphDataImpl, this, num_nodes, edge, node_feat,
|
||||
edge_feat, graph_feat, node_type, edge_type)));
|
||||
for (int32_t i = 0; i < num_workers_; ++i) {
|
||||
RETURN_IF_NOT_OK(
|
||||
tg_->CreateAsyncTask("start async rpc service", std::bind(&GraphDataServer::StartAsyncRpcService, this)));
|
||||
}
|
||||
if (auto_shutdown_) {
|
||||
RETURN_IF_NOT_OK(
|
||||
tg_->CreateAsyncTask("judge auto shutdown server", std::bind(&GraphDataServer::JudgeAutoShutdownServer, this)));
|
||||
}
|
||||
return Status::OK();
|
||||
#endif
|
||||
}
|
||||
|
||||
Status GraphDataServer::InitGraphDataImpl() {
|
||||
TaskManager::FindMe()->Post();
|
||||
Status s = graph_data_impl_->Init();
|
||||
|
@ -73,6 +99,21 @@ Status GraphDataServer::InitGraphDataImpl() {
|
|||
}
|
||||
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
Status GraphDataServer::InitNumpyGraphDataImpl(int32_t num_nodes, std::shared_ptr<Tensor> edge,
|
||||
std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> node_feat,
|
||||
std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> edge_feat,
|
||||
std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> graph_feat,
|
||||
std::shared_ptr<Tensor> node_type, std::shared_ptr<Tensor> edge_type) {
|
||||
TaskManager::FindMe()->Post();
|
||||
Status s = graph_data_impl_->Init(num_nodes, edge, node_feat, edge_feat, graph_feat, node_type, edge_type);
|
||||
if (s.IsOk()) {
|
||||
set_state(kGdsRunning);
|
||||
} else {
|
||||
(void)Stop();
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
Status GraphDataServer::StartAsyncRpcService() {
|
||||
TaskManager::FindMe()->Post();
|
||||
RETURN_IF_NOT_OK(async_server_->HandleRequest());
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -19,6 +19,7 @@
|
|||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
|
@ -26,6 +27,7 @@
|
|||
#include "minddata/dataset/engine/gnn/graph_data_service_impl.h"
|
||||
#include "minddata/dataset/engine/gnn/grpc_async_server.h"
|
||||
#endif
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/util/task_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -37,12 +39,19 @@ class GraphDataImpl;
|
|||
class GraphDataServer {
|
||||
public:
|
||||
enum ServerState { kGdsUninit = 0, kGdsInitializing, kGdsRunning, kGdsStopped };
|
||||
GraphDataServer(const std::string &dataset_file, int32_t num_workers, const std::string &hostname, int32_t port,
|
||||
int32_t client_num, bool auto_shutdown);
|
||||
GraphDataServer(const std::string &data_format, const std::string &dataset_file, int32_t num_workers,
|
||||
const std::string &hostname, int32_t port, int32_t client_num, bool auto_shutdown);
|
||||
|
||||
~GraphDataServer() = default;
|
||||
|
||||
Status Init();
|
||||
|
||||
Status Init(int32_t num_nodes, const std::shared_ptr<Tensor> &edge,
|
||||
const std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> &node_feat,
|
||||
const std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> &edge_feat,
|
||||
const std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> &graph_feat,
|
||||
const std::shared_ptr<Tensor> &node_type, const std::shared_ptr<Tensor> &edge_type);
|
||||
|
||||
Status Stop();
|
||||
|
||||
Status ClientRegister(int32_t pid);
|
||||
|
@ -62,11 +71,19 @@ class GraphDataServer {
|
|||
void set_state(enum ServerState state) { state_ = state; }
|
||||
|
||||
Status InitGraphDataImpl();
|
||||
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
Status InitNumpyGraphDataImpl(int32_t num_nodes, std::shared_ptr<Tensor> edge,
|
||||
std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> node_feat,
|
||||
std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> edge_feat,
|
||||
std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> graph_feat,
|
||||
std::shared_ptr<Tensor> node_type, std::shared_ptr<Tensor> edge_type);
|
||||
|
||||
Status StartAsyncRpcService();
|
||||
#endif
|
||||
Status JudgeAutoShutdownServer();
|
||||
|
||||
std::string data_format_;
|
||||
std::string dataset_file_;
|
||||
int32_t num_workers_; // The number of worker threads
|
||||
int32_t client_num_;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -45,17 +45,24 @@ GraphDataServiceImpl::GraphDataServiceImpl(GraphDataServer *server, GraphDataImp
|
|||
|
||||
Status GraphDataServiceImpl::FillDefaultFeature(GnnClientRegisterResponsePb *response) {
|
||||
const auto default_node_features = graph_data_impl_->GetAllDefaultNodeFeatures();
|
||||
for (const auto feature : *default_node_features) {
|
||||
for (const auto &feature : *default_node_features) {
|
||||
GnnFeatureInfoPb *feature_info = response->add_default_node_feature();
|
||||
feature_info->set_type(feature.first);
|
||||
RETURN_IF_NOT_OK(TensorToPb(feature.second->Value(), feature_info->mutable_feature()));
|
||||
}
|
||||
const auto default_edge_features = graph_data_impl_->GetAllDefaultEdgeFeatures();
|
||||
for (const auto feature : *default_edge_features) {
|
||||
for (const auto &feature : *default_edge_features) {
|
||||
GnnFeatureInfoPb *feature_info = response->add_default_edge_feature();
|
||||
feature_info->set_type(feature.first);
|
||||
RETURN_IF_NOT_OK(TensorToPb(feature.second->Value(), feature_info->mutable_feature()));
|
||||
}
|
||||
const auto graph_features = graph_data_impl_->GetAllGraphFeatures();
|
||||
for (const auto &feature : *graph_features) {
|
||||
GnnFeatureInfoPb *feature_info = response->add_graph_feature();
|
||||
feature_info->set_type(feature.first);
|
||||
RETURN_IF_NOT_OK(TensorToPb(feature.second->Value(), feature_info->mutable_feature()));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -24,6 +24,7 @@
|
|||
#include "minddata/dataset/engine/gnn/local_node.h"
|
||||
#include "minddata/dataset/util/task_manager.h"
|
||||
#include "minddata/mindrecord/include/shard_error.h"
|
||||
#include "utils/file_utils.h"
|
||||
|
||||
using ShardTuple = std::vector<std::tuple<std::vector<uint8_t>, mindspore::mindrecord::json>>;
|
||||
namespace mindspore {
|
||||
|
@ -44,10 +45,11 @@ GraphLoader::GraphLoader(GraphDataImpl *graph_impl, std::string mr_filepath, int
|
|||
optional_key_({{"weight", false}}) {}
|
||||
|
||||
Status GraphLoader::GetNodesAndEdges() {
|
||||
MS_LOG(INFO) << "Start to fill node and edges into graph.";
|
||||
NodeIdMap *n_id_map = &graph_impl_->node_id_map_;
|
||||
EdgeIdMap *e_id_map = &graph_impl_->edge_id_map_;
|
||||
for (std::deque<std::shared_ptr<Node>> &dq : n_deques_) {
|
||||
while (dq.empty() == false) {
|
||||
while (!dq.empty()) {
|
||||
std::shared_ptr<Node> node_ptr = dq.front();
|
||||
n_id_map->insert({node_ptr->id(), node_ptr});
|
||||
graph_impl_->node_type_map_[node_ptr->type()].push_back(node_ptr->id());
|
||||
|
@ -56,16 +58,21 @@ Status GraphLoader::GetNodesAndEdges() {
|
|||
}
|
||||
|
||||
for (std::deque<std::shared_ptr<Edge>> &dq : e_deques_) {
|
||||
while (dq.empty() == false) {
|
||||
while (!dq.empty()) {
|
||||
std::shared_ptr<Edge> edge_ptr = dq.front();
|
||||
std::pair<std::shared_ptr<Node>, std::shared_ptr<Node>> p;
|
||||
RETURN_IF_NOT_OK(edge_ptr->GetNode(&p));
|
||||
auto src_itr = n_id_map->find(p.first->id()), dst_itr = n_id_map->find(p.second->id());
|
||||
NodeIdType src_id, dst_id;
|
||||
RETURN_IF_NOT_OK(edge_ptr->GetNode(&src_id, &dst_id));
|
||||
auto src_itr = n_id_map->find(src_id), dst_itr = n_id_map->find(dst_id);
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(src_itr != n_id_map->end(), "invalid src_id.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(dst_itr != n_id_map->end(), "invalid src_id.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
src_itr != n_id_map->end(),
|
||||
"[Internal Error] src node with id '" + std::to_string(src_id) + "' has not been created yet.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
dst_itr != n_id_map->end(),
|
||||
"[Internal Error] dst node with id '" + std::to_string(dst_id) + "' has not been created yet.");
|
||||
|
||||
RETURN_IF_NOT_OK(edge_ptr->SetNode(src_itr->second->id(), dst_itr->second->id()));
|
||||
|
||||
RETURN_IF_NOT_OK(edge_ptr->SetNode({src_itr->second, dst_itr->second}));
|
||||
RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second, edge_ptr->weight()));
|
||||
RETURN_IF_NOT_OK(src_itr->second->AddAdjacent(dst_itr->second, edge_ptr));
|
||||
|
||||
|
@ -85,6 +92,12 @@ Status GraphLoader::GetNodesAndEdges() {
|
|||
Status GraphLoader::InitAndLoad() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "num_reader can't be < 1\n");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(row_id_ == 0, "InitAndLoad Can only be called once!\n");
|
||||
|
||||
auto realpath = FileUtils::GetRealPath(mr_path_.c_str());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
realpath.has_value(),
|
||||
"Invalid file, failed to get the realpath of mindrecord files. Please check file: " + mr_path_);
|
||||
|
||||
n_deques_.resize(num_workers_);
|
||||
e_deques_.resize(num_workers_);
|
||||
n_feature_maps_.resize(num_workers_);
|
||||
|
@ -118,6 +131,8 @@ Status GraphLoader::InitAndLoad() {
|
|||
RETURN_IF_NOT_OK(shard_reader_->GetTotalBlobSize(&total_blob_size));
|
||||
graph_impl_->graph_shared_memory_ = std::make_unique<GraphSharedMemory>(total_blob_size, mr_path_);
|
||||
RETURN_IF_NOT_OK(graph_impl_->graph_shared_memory_->CreateSharedMemory());
|
||||
#else
|
||||
RETURN_STATUS_UNEXPECTED("Server mode is not supported in Windows OS.");
|
||||
#endif
|
||||
}
|
||||
|
||||
|
@ -163,6 +178,8 @@ Status GraphLoader::LoadNode(const std::vector<uint8_t> &col_blob, const mindrec
|
|||
(*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor);
|
||||
}
|
||||
}
|
||||
#else
|
||||
RETURN_STATUS_UNEXPECTED("Server mode is not supported in Windows OS.");
|
||||
#endif
|
||||
} else {
|
||||
for (int32_t ind : indices) {
|
||||
|
@ -192,9 +209,7 @@ Status GraphLoader::LoadEdge(const std::vector<uint8_t> &col_blob, const mindrec
|
|||
if (optional_key_["weight"]) {
|
||||
edge_weight = col_jsn["weight"];
|
||||
}
|
||||
std::shared_ptr<Node> src = std::make_shared<LocalNode>(src_id, -1, 1);
|
||||
std::shared_ptr<Node> dst = std::make_shared<LocalNode>(dst_id, -1, 1);
|
||||
(*edge) = std::make_shared<LocalEdge>(edge_id, edge_type, edge_weight, src, dst);
|
||||
(*edge) = std::make_shared<LocalEdge>(edge_id, edge_type, edge_weight, src_id, dst_id);
|
||||
std::vector<int32_t> indices;
|
||||
RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureIndex("edge_feature_index", col_blob, &indices));
|
||||
if (graph_impl_->server_mode_) {
|
||||
|
@ -215,6 +230,8 @@ Status GraphLoader::LoadEdge(const std::vector<uint8_t> &col_blob, const mindrec
|
|||
(*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor);
|
||||
}
|
||||
}
|
||||
#else
|
||||
RETURN_STATUS_UNEXPECTED("Server mode is not supported in Windows OS.");
|
||||
#endif
|
||||
} else {
|
||||
for (int32_t ind : indices) {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -59,16 +59,27 @@ class GraphLoader {
|
|||
GraphLoader(GraphDataImpl *graph_impl, std::string mr_filepath, int32_t num_workers = 4, bool server_mode = false);
|
||||
|
||||
~GraphLoader() = default;
|
||||
// Init mindrecord and load everything into memory multi-threaded
|
||||
// Init mindrecord or array and load everything into memory multi-threaded
|
||||
// @return Status - the status code
|
||||
Status InitAndLoad();
|
||||
virtual Status InitAndLoad();
|
||||
|
||||
// this function will query mindrecord and construct all nodes and edges
|
||||
// this function will query mindrecord or array and construct all nodes and edges
|
||||
// nodes and edges are added to map without any connection. That's because there nodes and edges are read in
|
||||
// random order. src_node and dst_node in Edge are node_id only with -1 as type.
|
||||
// features attached to each node and edge are expected to be filled correctly
|
||||
Status GetNodesAndEdges();
|
||||
|
||||
protected:
|
||||
// merge NodeFeatureMap and EdgeFeatureMap of each worker into 1
|
||||
void MergeFeatureMaps();
|
||||
std::vector<std::deque<std::shared_ptr<Node>>> n_deques_;
|
||||
std::vector<std::deque<std::shared_ptr<Edge>>> e_deques_;
|
||||
std::vector<NodeFeatureMap> n_feature_maps_;
|
||||
std::vector<EdgeFeatureMap> e_feature_maps_;
|
||||
std::vector<DefaultNodeFeatureMap> default_node_feature_maps_;
|
||||
std::vector<DefaultEdgeFeatureMap> default_edge_feature_maps_;
|
||||
GraphDataImpl *graph_impl_;
|
||||
|
||||
private:
|
||||
//
|
||||
// worker thread that reads mindrecord file
|
||||
|
@ -95,21 +106,11 @@ class GraphLoader {
|
|||
Status LoadEdge(const std::vector<uint8_t> &blob, const mindrecord::json &jsn, std::shared_ptr<Edge> *edge,
|
||||
EdgeFeatureMap *feature_map, DefaultEdgeFeatureMap *default_feature);
|
||||
|
||||
// merge NodeFeatureMap and EdgeFeatureMap of each worker into 1
|
||||
void MergeFeatureMaps();
|
||||
|
||||
GraphDataImpl *graph_impl_;
|
||||
std::string mr_path_;
|
||||
const int32_t num_workers_;
|
||||
std::atomic_int row_id_;
|
||||
std::unique_ptr<ShardReader> shard_reader_;
|
||||
std::unique_ptr<GraphFeatureParser> graph_feature_parser_;
|
||||
std::vector<std::deque<std::shared_ptr<Node>>> n_deques_;
|
||||
std::vector<std::deque<std::shared_ptr<Edge>>> e_deques_;
|
||||
std::vector<NodeFeatureMap> n_feature_maps_;
|
||||
std::vector<EdgeFeatureMap> e_feature_maps_;
|
||||
std::vector<DefaultNodeFeatureMap> default_node_feature_maps_;
|
||||
std::vector<DefaultEdgeFeatureMap> default_edge_feature_maps_;
|
||||
const std::vector<std::string> required_key_;
|
||||
std::unordered_map<std::string, bool> optional_key_;
|
||||
};
|
||||
|
|
|
@ -0,0 +1,287 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/engine/gnn/graph_loader_array.h"
|
||||
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
#include <sys/ipc.h>
|
||||
#endif
|
||||
|
||||
#include <unistd.h>
|
||||
#include <future>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
|
||||
#include "minddata/dataset/engine/gnn/graph_data_impl.h"
|
||||
#include "minddata/dataset/engine/gnn/local_edge.h"
|
||||
#include "minddata/dataset/engine/gnn/local_node.h"
|
||||
#include "minddata/dataset/util/task_manager.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace gnn {
|
||||
const FeatureType weight_feature_type = -1;
|
||||
|
||||
GraphLoaderFromArray::GraphLoaderFromArray(GraphDataImpl *graph_impl, int32_t num_nodes,
|
||||
const std::shared_ptr<Tensor> &edge,
|
||||
const std::unordered_map<FeatureType, std::shared_ptr<Tensor>> &node_feat,
|
||||
const std::unordered_map<FeatureType, std::shared_ptr<Tensor>> &edge_feat,
|
||||
const std::unordered_map<FeatureType, std::shared_ptr<Tensor>> &graph_feat,
|
||||
const std::shared_ptr<Tensor> &node_type,
|
||||
const std::shared_ptr<Tensor> &edge_type, int32_t num_workers,
|
||||
bool server_mode)
|
||||
: GraphLoader(graph_impl, "", num_workers),
|
||||
num_nodes_(num_nodes),
|
||||
edge_(edge),
|
||||
node_feat_(node_feat),
|
||||
edge_feat_(edge_feat),
|
||||
graph_feat_(graph_feat),
|
||||
node_type_(node_type),
|
||||
edge_type_(edge_type),
|
||||
num_workers_(num_workers) {}
|
||||
|
||||
Status GraphLoaderFromArray::InitAndLoad() {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_workers_ > 0, "num_workers should be equal or great than 1.");
|
||||
n_deques_.resize(num_workers_);
|
||||
e_deques_.resize(num_workers_);
|
||||
n_feature_maps_.resize(num_workers_);
|
||||
e_feature_maps_.resize(num_workers_);
|
||||
default_node_feature_maps_.resize(num_workers_);
|
||||
default_edge_feature_maps_.resize(num_workers_);
|
||||
TaskGroup vg;
|
||||
|
||||
if (graph_impl_->server_mode_) {
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
// obtain the size that required for store feature, if add_node or add_edge later, this should be larger initially
|
||||
int64_t total_feature_size = 0;
|
||||
total_feature_size = std::accumulate(node_feat_.begin(), node_feat_.end(), total_feature_size,
|
||||
[](int64_t temp_size, std::pair<FeatureType, std::shared_ptr<Tensor>> item) {
|
||||
return temp_size + item.second->SizeInBytes();
|
||||
});
|
||||
total_feature_size = std::accumulate(edge_feat_.begin(), edge_feat_.end(), total_feature_size,
|
||||
[](int64_t temp_size, std::pair<FeatureType, std::shared_ptr<Tensor>> item) {
|
||||
return temp_size + item.second->SizeInBytes();
|
||||
});
|
||||
|
||||
MS_LOG(INFO) << "Total feature size in input data is(byte):" << total_feature_size;
|
||||
|
||||
// generate memory_key
|
||||
char file_name[] = "/tmp/tempfile_XXXXXX";
|
||||
int fd = mkstemp(file_name);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(fd != -1, "create temp file failed when create graph with loading array data.");
|
||||
auto memory_key = ftok(file_name, kGnnSharedMemoryId);
|
||||
auto err = unlink(file_name);
|
||||
std::string err_msg = "unable to delete file:";
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(err != -1, err_msg + file_name);
|
||||
|
||||
close(fd);
|
||||
graph_impl_->graph_shared_memory_ = std::make_unique<GraphSharedMemory>(total_feature_size, memory_key);
|
||||
RETURN_IF_NOT_OK(graph_impl_->graph_shared_memory_->CreateSharedMemory());
|
||||
#else
|
||||
RETURN_STATUS_UNEXPECTED("Server mode is not supported in Windows OS.");
|
||||
#endif
|
||||
}
|
||||
|
||||
// load graph feature into memory firstly
|
||||
for (const auto &item : graph_feat_) {
|
||||
graph_impl_->graph_feature_map_[item.first] = std::make_shared<Feature>(item.first, item.second);
|
||||
}
|
||||
graph_feat_.clear();
|
||||
|
||||
// deal with weight in node and edge firstly
|
||||
auto weight_itr = node_feat_.find(weight_feature_type);
|
||||
if (weight_itr != node_feat_.end()) {
|
||||
node_weight_ = weight_itr->second;
|
||||
node_feat_.erase(weight_feature_type);
|
||||
}
|
||||
weight_itr = edge_feat_.find(weight_feature_type);
|
||||
if (weight_itr != edge_feat_.end()) {
|
||||
edge_weight_ = weight_itr->second;
|
||||
edge_feat_.erase(weight_feature_type);
|
||||
}
|
||||
|
||||
for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) {
|
||||
RETURN_IF_NOT_OK(
|
||||
vg.CreateAsyncTask("GraphLoaderFromArray", std::bind(&GraphLoaderFromArray::WorkerEntry, this, wkr_id)));
|
||||
}
|
||||
|
||||
// wait for threads to finish and check its return code
|
||||
RETURN_IF_NOT_OK(vg.join_all(Task::WaitFlag::kBlocking));
|
||||
RETURN_IF_NOT_OK(vg.GetTaskErrorIfAny());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphLoaderFromArray::WorkerEntry(int32_t worker_id) {
|
||||
TaskManager::FindMe()->Post();
|
||||
RETURN_IF_NOT_OK(LoadNode(worker_id));
|
||||
RETURN_IF_NOT_OK(LoadEdge(worker_id));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphLoaderFromArray::LoadNode(int32_t worker_id) {
|
||||
MS_LOG(INFO) << "start Load Node, worker id is:" << worker_id;
|
||||
for (NodeIdType i = worker_id; i < num_nodes_; i = i + num_workers_) {
|
||||
WeightType weight = 1.0;
|
||||
NodeType node_type;
|
||||
if (node_weight_ != nullptr) {
|
||||
RETURN_IF_NOT_OK(node_weight_->GetItemAt<WeightType>(&weight, {i}));
|
||||
}
|
||||
RETURN_IF_NOT_OK(node_type_->GetItemAt<NodeType>(&node_type, {i}));
|
||||
std::shared_ptr<Node> node_ptr = std::make_shared<LocalNode>(i, node_type, weight);
|
||||
|
||||
if (graph_impl_->server_mode_) {
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
for (const auto &item : node_feat_) {
|
||||
std::shared_ptr<Tensor> tensor_sm;
|
||||
RETURN_IF_NOT_OK(LoadFeatureToSharedMemory(i, item, &tensor_sm));
|
||||
RETURN_IF_NOT_OK(node_ptr->UpdateFeature(std::make_shared<Feature>(item.first, tensor_sm, true)));
|
||||
n_feature_maps_[worker_id][node_type].insert(item.first);
|
||||
|
||||
// this may only need execute once, as all node has the same feature type
|
||||
if (default_node_feature_maps_[worker_id][item.first] == nullptr) {
|
||||
std::shared_ptr<Tensor> tensor = nullptr;
|
||||
std::shared_ptr<Tensor> zero_tensor;
|
||||
RETURN_IF_NOT_OK(LoadFeatureTensor(i, item, &tensor));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor));
|
||||
RETURN_IF_NOT_OK(zero_tensor->Zero());
|
||||
default_node_feature_maps_[worker_id][item.first] = std::make_shared<Feature>(item.first, zero_tensor);
|
||||
}
|
||||
}
|
||||
#else
|
||||
RETURN_STATUS_UNEXPECTED("Server mode is not supported in Windows OS.");
|
||||
#endif
|
||||
} else {
|
||||
for (const auto &item : node_feat_) {
|
||||
// get one row in corresponding node_feature
|
||||
std::shared_ptr<Tensor> feature_item;
|
||||
RETURN_IF_NOT_OK(LoadFeatureTensor(i, item, &feature_item));
|
||||
|
||||
RETURN_IF_NOT_OK(node_ptr->UpdateFeature(std::make_shared<Feature>(item.first, feature_item)));
|
||||
n_feature_maps_[worker_id][node_type].insert(item.first);
|
||||
// this may only need execute once, as all node has the same feature type
|
||||
if (default_node_feature_maps_[worker_id][item.first] == nullptr) {
|
||||
std::shared_ptr<Tensor> zero_tensor;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateEmpty(feature_item->shape(), feature_item->type(), &zero_tensor));
|
||||
RETURN_IF_NOT_OK(zero_tensor->Zero());
|
||||
default_node_feature_maps_[worker_id][item.first] = std::make_shared<Feature>(item.first, zero_tensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
n_deques_[worker_id].emplace_back(node_ptr);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GraphLoaderFromArray::LoadEdge(int32_t worker_id) {
|
||||
MS_LOG(INFO) << "Start Load Edge, worker id is:" << worker_id;
|
||||
RETURN_UNEXPECTED_IF_NULL(edge_);
|
||||
auto num_edges = edge_->shape()[1];
|
||||
for (EdgeIdType i = worker_id; i < num_edges; i = i + num_workers_) {
|
||||
// if weight exist in feature, then update it
|
||||
WeightType weight = 1.0;
|
||||
if (edge_weight_ != nullptr) {
|
||||
RETURN_IF_NOT_OK(edge_weight_->GetItemAt<WeightType>(&weight, {i}));
|
||||
}
|
||||
NodeIdType src_id, dst_id;
|
||||
EdgeType edge_type;
|
||||
RETURN_IF_NOT_OK(edge_->GetItemAt<NodeIdType>(&src_id, {0, i}));
|
||||
RETURN_IF_NOT_OK(edge_->GetItemAt<NodeIdType>(&dst_id, {1, i}));
|
||||
RETURN_IF_NOT_OK(edge_type_->GetItemAt<EdgeType>(&edge_type, {i}));
|
||||
|
||||
std::shared_ptr<Edge> edge_ptr = std::make_shared<LocalEdge>(i, edge_type, weight, src_id, dst_id);
|
||||
if (graph_impl_->server_mode_) {
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
for (const auto &item : edge_feat_) {
|
||||
std::shared_ptr<Tensor> tensor_sm;
|
||||
RETURN_IF_NOT_OK(LoadFeatureToSharedMemory(i, item, &tensor_sm));
|
||||
RETURN_IF_NOT_OK(edge_ptr->UpdateFeature(std::make_shared<Feature>(item.first, tensor_sm, true)));
|
||||
e_feature_maps_[worker_id][edge_type].insert(item.first);
|
||||
|
||||
// this may only need execute once, as all node has the same feature type
|
||||
if (default_edge_feature_maps_[worker_id][item.first] == nullptr) {
|
||||
std::shared_ptr<Tensor> tensor = nullptr;
|
||||
std::shared_ptr<Tensor> zero_tensor;
|
||||
RETURN_IF_NOT_OK(LoadFeatureTensor(i, item, &tensor));
|
||||
RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor));
|
||||
RETURN_IF_NOT_OK(zero_tensor->Zero());
|
||||
default_edge_feature_maps_[worker_id][item.first] = std::make_shared<Feature>(item.first, zero_tensor);
|
||||
}
|
||||
}
|
||||
#else
|
||||
RETURN_STATUS_UNEXPECTED("Server mode is not supported in Windows OS.");
|
||||
#endif
|
||||
} else {
|
||||
for (const auto &item : edge_feat_) {
|
||||
std::shared_ptr<Tensor> feature_item;
|
||||
RETURN_IF_NOT_OK(LoadFeatureTensor(i, item, &feature_item));
|
||||
|
||||
RETURN_IF_NOT_OK(edge_ptr->UpdateFeature(std::make_shared<Feature>(item.first, feature_item)));
|
||||
e_feature_maps_[worker_id][edge_type].insert(item.first);
|
||||
// this may only need execute once, as all node has the same feature type
|
||||
if (default_edge_feature_maps_[worker_id][item.first] == nullptr) {
|
||||
std::shared_ptr<Tensor> zero_tensor;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateEmpty(feature_item->shape(), feature_item->type(), &zero_tensor));
|
||||
RETURN_IF_NOT_OK(zero_tensor->Zero());
|
||||
default_edge_feature_maps_[worker_id][item.first] = std::make_shared<Feature>(item.first, zero_tensor);
|
||||
}
|
||||
}
|
||||
}
|
||||
e_deques_[worker_id].emplace_back(edge_ptr);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
Status GraphLoaderFromArray::LoadFeatureToSharedMemory(int32_t i, std::pair<int16_t, std::shared_ptr<Tensor>> item,
|
||||
std::shared_ptr<Tensor> *out_tensor) {
|
||||
auto feature_num = item.second->shape()[1];
|
||||
uint8_t type_size = item.second->type().SizeInBytes();
|
||||
dsize_t src_flat_ind = 0;
|
||||
RETURN_IF_NOT_OK(item.second->shape().ToFlatIndex({i, 0}, &src_flat_ind));
|
||||
auto start_ptr = item.second->GetBuffer() + src_flat_ind * type_size;
|
||||
|
||||
dsize_t n_bytes = feature_num * type_size;
|
||||
int64_t offset = 0;
|
||||
auto shared_memory = graph_impl_->graph_shared_memory_.get();
|
||||
RETURN_IF_NOT_OK(shared_memory->InsertData(start_ptr, n_bytes, &offset));
|
||||
|
||||
std::shared_ptr<Tensor> tensor;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateEmpty(std::move(TensorShape({2})), std::move(DataType(DataType::DE_INT64)), &tensor));
|
||||
auto fea_itr = tensor->begin<int64_t>();
|
||||
*fea_itr = offset;
|
||||
++fea_itr;
|
||||
*fea_itr = n_bytes;
|
||||
*out_tensor = std::move(tensor);
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
Status GraphLoaderFromArray::LoadFeatureTensor(int32_t i, std::pair<int16_t, std::shared_ptr<Tensor>> item,
|
||||
std::shared_ptr<Tensor> *tensor) {
|
||||
std::shared_ptr<Tensor> feature_item;
|
||||
auto feature_num = item.second->shape()[1];
|
||||
uint8_t type_size = item.second->type().SizeInBytes();
|
||||
dsize_t src_flat_ind = 0;
|
||||
RETURN_IF_NOT_OK(item.second->shape().ToFlatIndex({i, 0}, &src_flat_ind));
|
||||
auto start_ptr = item.second->GetBuffer() + src_flat_ind * type_size;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromMemory(TensorShape({feature_num}), item.second->type(), start_ptr, &feature_item));
|
||||
|
||||
*tensor = std::move(feature_item);
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace gnn
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,109 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_LOADER_ARRAY_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_LOADER_ARRAY_H_
|
||||
|
||||
#include <deque>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
|
||||
#include "minddata/dataset/core/data_type.h"
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/engine/gnn/edge.h"
|
||||
#include "minddata/dataset/engine/gnn/feature.h"
|
||||
#include "minddata/dataset/engine/gnn/graph_feature_parser.h"
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
#include "minddata/dataset/engine/gnn/graph_shared_memory.h"
|
||||
#endif
|
||||
#include "minddata/dataset/engine/gnn/node.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
#include "minddata/dataset/engine/gnn/graph_data_impl.h"
|
||||
#include "minddata/dataset/engine/gnn/graph_loader.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace gnn {
|
||||
class GraphLoaderFromArray : public GraphLoader {
|
||||
public:
|
||||
// Create graph with loading numpy array.
|
||||
GraphLoaderFromArray(GraphDataImpl *graph_impl, int32_t num_nodes, const std::shared_ptr<Tensor> &edge,
|
||||
const std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> &node_feat,
|
||||
const std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> &edge_feat,
|
||||
const std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> &graph_feat,
|
||||
const std::shared_ptr<Tensor> &node_type, const std::shared_ptr<Tensor> &edge_type,
|
||||
int32_t num_workers = 4, bool server_mode = false);
|
||||
|
||||
/// \brief default destructor
|
||||
~GraphLoaderFromArray() = default;
|
||||
|
||||
// Init array and load everything into memory multi-threaded
|
||||
// @return Status - the status code
|
||||
Status InitAndLoad() override;
|
||||
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
// load feature into shared memory
|
||||
// @param int32_t i - feature index
|
||||
// @param std::pair<int16_t, std::shared_ptr<Tensor>> item - contain feature type and feature value
|
||||
// @param std::shared_ptr<Tensor> *out_tensor, Tensor that convert from corresponding feature
|
||||
// @return Status - the status code
|
||||
Status LoadFeatureToSharedMemory(int32_t i, std::pair<int16_t, std::shared_ptr<Tensor>> item,
|
||||
std::shared_ptr<Tensor> *out_tensor);
|
||||
#endif
|
||||
|
||||
// load feature item
|
||||
// @param int32_t i - feature index
|
||||
// @param std::pair<int16_t, std::shared_ptr<Tensor>> item - contain feature type and feature value
|
||||
// @param std::shared_ptr<Tensor> *tensor, Tensor that convert from corresponding feature
|
||||
// @return Status - the status code
|
||||
Status LoadFeatureTensor(int32_t i, std::pair<int16_t, std::shared_ptr<Tensor>> item,
|
||||
std::shared_ptr<Tensor> *tensor);
|
||||
|
||||
private:
|
||||
// worker thread that reads array data
|
||||
// @param int32_t worker_id - id of each worker
|
||||
// @return Status - the status code
|
||||
Status WorkerEntry(int32_t worker_id);
|
||||
|
||||
// Load node into memory, returns a shared_ptr<Node>
|
||||
// @return Status - the status code
|
||||
Status LoadNode(int32_t worker_id);
|
||||
|
||||
// Load edge into memory, returns a shared_ptr<Edge>
|
||||
// @return Status - the status code
|
||||
Status LoadEdge(int32_t worker_id);
|
||||
|
||||
int32_t num_nodes_;
|
||||
const int32_t num_workers_;
|
||||
std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> node_feat_;
|
||||
std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> edge_feat_;
|
||||
std::unordered_map<std::int16_t, std::shared_ptr<Tensor>> graph_feat_;
|
||||
std::shared_ptr<Tensor> edge_ = nullptr;
|
||||
std::shared_ptr<Tensor> node_type_ = nullptr;
|
||||
std::shared_ptr<Tensor> edge_type_ = nullptr;
|
||||
std::shared_ptr<Tensor> node_weight_ = nullptr;
|
||||
std::shared_ptr<Tensor> edge_weight_ = nullptr;
|
||||
};
|
||||
} // namespace gnn
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_LOADER_ARRAY_H_
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -21,12 +21,13 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
namespace gnn {
|
||||
|
||||
LocalEdge::LocalEdge(EdgeIdType id, EdgeType type, WeightType weight, std::shared_ptr<Node> src_node,
|
||||
std::shared_ptr<Node> dst_node)
|
||||
: Edge(id, type, weight, src_node, dst_node) {}
|
||||
LocalEdge::LocalEdge(EdgeIdType id, EdgeType type, WeightType weight, NodeIdType src_id, NodeIdType dst_id)
|
||||
: Edge(id, type, weight, src_id, dst_id) {}
|
||||
|
||||
Status LocalEdge::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) {
|
||||
auto itr = features_.find(feature_type);
|
||||
auto itr = std::find_if(
|
||||
features_.begin(), features_.end(),
|
||||
[feature_type](std::pair<FeatureType, std::shared_ptr<Feature>> item) { return item.first == feature_type; });
|
||||
if (itr != features_.end()) {
|
||||
*out_feature = itr->second;
|
||||
return Status::OK();
|
||||
|
@ -37,11 +38,14 @@ Status LocalEdge::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature>
|
|||
}
|
||||
|
||||
Status LocalEdge::UpdateFeature(const std::shared_ptr<Feature> &feature) {
|
||||
auto itr = features_.find(feature->type());
|
||||
auto itr = std::find_if(
|
||||
features_.begin(), features_.end(),
|
||||
[feature](std::pair<FeatureType, std::shared_ptr<Feature>> item) { return item.first == feature->type(); });
|
||||
|
||||
if (itr != features_.end()) {
|
||||
RETURN_STATUS_UNEXPECTED("Feature already exists");
|
||||
} else {
|
||||
features_[feature->type()] = feature;
|
||||
features_.emplace_back(std::make_pair(feature->type(), feature));
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -19,6 +19,7 @@
|
|||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/gnn/edge.h"
|
||||
#include "minddata/dataset/engine/gnn/feature.h"
|
||||
|
@ -35,10 +36,9 @@ class LocalEdge : public Edge {
|
|||
// @param EdgeIdType id - edge id
|
||||
// @param EdgeType type - edge type
|
||||
// @param WeightType weight - edge weight
|
||||
// @param std::shared_ptr<Node> src_node - source node
|
||||
// @param std::shared_ptr<Node> dst_node - destination node
|
||||
LocalEdge(EdgeIdType id, EdgeType type, WeightType weight, std::shared_ptr<Node> src_node,
|
||||
std::shared_ptr<Node> dst_node);
|
||||
// @param NodeIdType src_id - source node id
|
||||
// @param NodeIdType dst_id - destination node id
|
||||
LocalEdge(EdgeIdType id, EdgeType type, WeightType weight, NodeIdType src_id, NodeIdType dst_id);
|
||||
|
||||
~LocalEdge() = default;
|
||||
|
||||
|
@ -54,7 +54,7 @@ class LocalEdge : public Edge {
|
|||
Status UpdateFeature(const std::shared_ptr<Feature> &feature) override;
|
||||
|
||||
private:
|
||||
std::unordered_map<FeatureType, std::shared_ptr<Feature>> features_;
|
||||
std::vector<std::pair<FeatureType, std::shared_ptr<Feature>>> features_;
|
||||
};
|
||||
} // namespace gnn
|
||||
} // namespace dataset
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -21,19 +21,17 @@
|
|||
#include <utility>
|
||||
|
||||
#include "minddata/dataset/engine/gnn/edge.h"
|
||||
#include "minddata/dataset/util/random.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace gnn {
|
||||
|
||||
LocalNode::LocalNode(NodeIdType id, NodeType type, WeightType weight)
|
||||
: Node(id, type, weight), rnd_(GetRandomDevice()) {
|
||||
rnd_.seed(GetSeed());
|
||||
}
|
||||
LocalNode::LocalNode(NodeIdType id, NodeType type, WeightType weight) : Node(id, type, weight) {}
|
||||
|
||||
Status LocalNode::GetFeatures(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) {
|
||||
auto itr = features_.find(feature_type);
|
||||
auto itr = std::find_if(
|
||||
features_.begin(), features_.end(),
|
||||
[feature_type](std::pair<FeatureType, std::shared_ptr<Feature>> item) { return item.first == feature_type; });
|
||||
if (itr != features_.end()) {
|
||||
*out_feature = itr->second;
|
||||
return Status::OK();
|
||||
|
@ -68,10 +66,10 @@ Status LocalNode::GetAllNeighbors(NodeType neighbor_type, std::vector<NodeIdType
|
|||
}
|
||||
|
||||
Status LocalNode::GetRandomSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors, int32_t samples_num,
|
||||
std::vector<NodeIdType> *out) {
|
||||
std::vector<NodeIdType> *out, std::mt19937 *rnd) {
|
||||
std::vector<NodeIdType> shuffled_id(neighbors.size());
|
||||
std::iota(shuffled_id.begin(), shuffled_id.end(), 0);
|
||||
std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_);
|
||||
std::shuffle(shuffled_id.begin(), shuffled_id.end(), *rnd);
|
||||
int32_t num = std::min(samples_num, static_cast<int32_t>(neighbors.size()));
|
||||
for (int32_t i = 0; i < num; ++i) {
|
||||
out->emplace_back(neighbors[shuffled_id[i]]->id());
|
||||
|
@ -81,29 +79,29 @@ Status LocalNode::GetRandomSampledNeighbors(const std::vector<std::shared_ptr<No
|
|||
|
||||
Status LocalNode::GetWeightSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors,
|
||||
const std::vector<WeightType> &weights, int32_t samples_num,
|
||||
std::vector<NodeIdType> *out) {
|
||||
std::vector<NodeIdType> *out, std::mt19937 *rnd) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(neighbors.size() == weights.size(),
|
||||
"The number of neighbors does not match the weight.");
|
||||
std::discrete_distribution<NodeIdType> discrete_dist(weights.begin(), weights.end());
|
||||
for (int32_t i = 0; i < samples_num; ++i) {
|
||||
NodeIdType index = discrete_dist(rnd_);
|
||||
NodeIdType index = discrete_dist(*rnd);
|
||||
out->emplace_back(neighbors[index]->id());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LocalNode::GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, SamplingStrategy strategy,
|
||||
std::vector<NodeIdType> *out_neighbors) {
|
||||
std::vector<NodeIdType> *out_neighbors, std::mt19937 *rnd) {
|
||||
std::vector<NodeIdType> neighbors;
|
||||
neighbors.reserve(samples_num);
|
||||
auto itr = neighbor_nodes_.find(neighbor_type);
|
||||
if (itr != neighbor_nodes_.end()) {
|
||||
if (strategy == SamplingStrategy::kRandom) {
|
||||
while (neighbors.size() < samples_num) {
|
||||
RETURN_IF_NOT_OK(GetRandomSampledNeighbors(itr->second.first, samples_num - neighbors.size(), &neighbors));
|
||||
RETURN_IF_NOT_OK(GetRandomSampledNeighbors(itr->second.first, samples_num - neighbors.size(), &neighbors, rnd));
|
||||
}
|
||||
} else if (strategy == SamplingStrategy::kEdgeWeight) {
|
||||
RETURN_IF_NOT_OK(GetWeightSampledNeighbors(itr->second.first, itr->second.second, samples_num, &neighbors));
|
||||
RETURN_IF_NOT_OK(GetWeightSampledNeighbors(itr->second.first, itr->second.second, samples_num, &neighbors, rnd));
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("Invalid strategy");
|
||||
}
|
||||
|
@ -151,11 +149,13 @@ Status LocalNode::GetEdgeByAdjNodeId(const NodeIdType &adj_node_id, EdgeIdType *
|
|||
}
|
||||
|
||||
Status LocalNode::UpdateFeature(const std::shared_ptr<Feature> &feature) {
|
||||
auto itr = features_.find(feature->type());
|
||||
auto itr = std::find_if(
|
||||
features_.begin(), features_.end(),
|
||||
[feature](std::pair<FeatureType, std::shared_ptr<Feature>> item) { return item.first == feature->type(); });
|
||||
if (itr != features_.end()) {
|
||||
RETURN_STATUS_UNEXPECTED("Feature already exists");
|
||||
} else {
|
||||
features_[feature->type()] = feature;
|
||||
features_.emplace_back(std::make_pair(feature->type(), feature));
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -58,7 +58,7 @@ class LocalNode : public Node {
|
|||
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
|
||||
// @return Status The status code returned
|
||||
Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, SamplingStrategy strategy,
|
||||
std::vector<NodeIdType> *out_neighbors) override;
|
||||
std::vector<NodeIdType> *out_neighbors, std::mt19937 *rnd) override;
|
||||
|
||||
// Add neighbor of node
|
||||
// @param std::shared_ptr<Node> node -
|
||||
|
@ -78,20 +78,20 @@ class LocalNode : public Node {
|
|||
Status GetEdgeByAdjNodeId(const NodeIdType &adj_node_id, EdgeIdType *out_edge_id) override;
|
||||
|
||||
// Update feature of node
|
||||
// @param std::shared_ptr<Feature> feature -
|
||||
// @param std::shared_ptr<Feature> feature
|
||||
// @return Status The status code returned
|
||||
Status UpdateFeature(const std::shared_ptr<Feature> &feature) override;
|
||||
|
||||
private:
|
||||
Status GetRandomSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors, int32_t samples_num,
|
||||
std::vector<NodeIdType> *out);
|
||||
std::vector<NodeIdType> *out, std::mt19937 *rnd);
|
||||
|
||||
Status GetWeightSampledNeighbors(const std::vector<std::shared_ptr<Node>> &neighbors,
|
||||
const std::vector<WeightType> &weights, int32_t samples_num,
|
||||
std::vector<NodeIdType> *out);
|
||||
std::vector<NodeIdType> *out, std::mt19937 *rnd);
|
||||
|
||||
std::mt19937 rnd_;
|
||||
std::unordered_map<FeatureType, std::shared_ptr<Feature>> features_;
|
||||
uint32_t rnd_seed_;
|
||||
std::vector<std::pair<FeatureType, std::shared_ptr<Feature>>> features_;
|
||||
std::unordered_map<NodeType, std::pair<std::vector<std::shared_ptr<Node>>, std::vector<WeightType>>> neighbor_nodes_;
|
||||
std::unordered_map<NodeIdType, EdgeIdType> adjacent_nodes_;
|
||||
};
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -21,6 +21,7 @@
|
|||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/engine/gnn/feature.h"
|
||||
#include "minddata/dataset/util/random.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -74,10 +75,10 @@ class Node {
|
|||
// @param std::vector<NodeIdType> *out_neighbors - Returned neighbors id
|
||||
// @return Status The status code returned
|
||||
virtual Status GetSampledNeighbors(NodeType neighbor_type, int32_t samples_num, SamplingStrategy strategy,
|
||||
std::vector<NodeIdType> *out_neighbors) = 0;
|
||||
std::vector<NodeIdType> *out_neighbors, std::mt19937 *rnd) = 0;
|
||||
|
||||
// Add neighbor of node
|
||||
// @param std::shared_ptr<Node> node -
|
||||
// @param std::shared_ptr<Node> node
|
||||
// @return Status The status code returned
|
||||
virtual Status AddNeighbor(const std::shared_ptr<Node> &node, const WeightType &weight) = 0;
|
||||
|
||||
|
|
|
@ -513,14 +513,15 @@ def type_check(arg, types, arg_name):
|
|||
Returns:
|
||||
Exception: when the validation fails, otherwise nothing.
|
||||
"""
|
||||
# handle special case of booleans being a subclass of ints
|
||||
print_value = '\"\"' if repr(arg) == repr('') else arg
|
||||
|
||||
if int in types and bool not in types:
|
||||
if isinstance(arg, bool):
|
||||
# handle special case of booleans being a subclass of ints
|
||||
print_value = '\"\"' if repr(arg) == repr('') else arg
|
||||
raise TypeError("Argument {0} with value {1} is not of type {2}, but got {3}.".format(arg_name, print_value,
|
||||
types, type(arg)))
|
||||
if not isinstance(arg, types):
|
||||
print_value = '\"\"' if repr(arg) == repr('') else arg
|
||||
raise TypeError("Argument {0} with value {1} is not of type {2}, but got {3}.".format(arg_name, print_value,
|
||||
list(types), type(arg)))
|
||||
|
||||
|
@ -719,13 +720,14 @@ def check_gnn_list_of_pair_or_ndarray(param, param_name):
|
|||
param_name, param.dtype))
|
||||
|
||||
|
||||
def check_gnn_list_or_ndarray(param, param_name):
|
||||
def check_gnn_list_or_ndarray(param, param_name, data_type=int):
|
||||
"""
|
||||
Check if the input parameter is list or numpy.ndarray.
|
||||
|
||||
Args:
|
||||
param (Union[list, nd.ndarray]): param.
|
||||
param_name (str): param_name.
|
||||
data_type(object): data type.
|
||||
|
||||
Returns:
|
||||
Exception: TypeError if error.
|
||||
|
@ -734,7 +736,7 @@ def check_gnn_list_or_ndarray(param, param_name):
|
|||
type_check(param, (list, np.ndarray), param_name)
|
||||
if isinstance(param, list):
|
||||
param_names = ["param_{0}".format(i) for i in range(len(param))]
|
||||
type_check_list(param, (int,), param_names)
|
||||
type_check_list(param, (data_type,), param_names)
|
||||
|
||||
elif isinstance(param, np.ndarray):
|
||||
if not param.dtype == np.int32:
|
||||
|
@ -802,3 +804,18 @@ def deprecator_factory(version, old_module, new_module, substitute_name=None, su
|
|||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def check_dict(data, key_type, value_type, param_name):
|
||||
""" check key and value type in dict."""
|
||||
if data is not None:
|
||||
if not isinstance(data, dict):
|
||||
raise TypeError("{0} should be dict type, but got: {1}".format(param_name, type(data)))
|
||||
|
||||
for key, value in data.items():
|
||||
if not isinstance(key, key_type):
|
||||
raise TypeError("key '{0}' in parameter {1} should be {2} type, but got: {3}"
|
||||
.format(key, param_name, key_type, type(key)))
|
||||
if not isinstance(value, value_type):
|
||||
raise TypeError("value of '{0}' in parameter {1} should be {2} type, but got: {3}"
|
||||
.format(key, param_name, value_type, type(value)))
|
||||
|
|
|
@ -31,7 +31,7 @@ from .datasets_text import *
|
|||
from .datasets_audio import *
|
||||
from .datasets_standard_format import *
|
||||
from .datasets_user_defined import *
|
||||
from .graphdata import GraphData, SamplingStrategy, OutputFormat
|
||||
from .graphdata import GraphData, Graph, InMemoryGraphDataset, ArgoverseDataset, SamplingStrategy, OutputFormat
|
||||
from .iterators import *
|
||||
from .obs.obs_mindrecord_dataset import *
|
||||
from .samplers import *
|
||||
|
@ -103,6 +103,9 @@ __all__ = ["Caltech101Dataset", # Vision
|
|||
"NumpySlicesDataset", # User Defined
|
||||
"PaddedDataset", # User Defined
|
||||
"GraphData", # Graph Data
|
||||
"Graph", # Graph
|
||||
"InMemoryGraphDataset", # InMemoryGraphDataset
|
||||
"ArgoverseDataset", # ArgoverseDataset
|
||||
"DistributedSampler", # Sampler
|
||||
"RandomSampler", # Sampler
|
||||
"SequentialSampler", # Sampler
|
||||
|
|
|
@ -1,25 +1,28 @@
|
|||
#Copyright 2020 Huawei Technologies Co., Ltd
|
||||
# Copyright 2020-2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0(the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
# Licensed under the Apache License, Version 2.0(the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
#http: // www.apache.org/licenses/LICENSE-2.0
|
||||
# http: // www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
#== == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == ==
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == ==
|
||||
"""
|
||||
graphdata.py supports loading graph dataset for GNN network training,
|
||||
and provides operations related to graph data.
|
||||
"""
|
||||
import atexit
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from enum import IntEnum
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from mindspore._c_dataengine import GraphDataClient
|
||||
from mindspore._c_dataengine import GraphDataServer
|
||||
from mindspore._c_dataengine import Tensor
|
||||
|
@ -29,7 +32,9 @@ from mindspore._c_dataengine import OutputFormat as Format
|
|||
from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \
|
||||
check_gnn_get_nodes_from_edges, check_gnn_get_edges_from_nodes, check_gnn_get_all_neighbors, \
|
||||
check_gnn_get_sampled_neighbors, check_gnn_get_neg_sampled_neighbors, check_gnn_get_node_feature, \
|
||||
check_gnn_get_edge_feature, check_gnn_random_walk
|
||||
check_gnn_get_edge_feature, check_gnn_random_walk, check_gnn_graph, check_gnn_get_graph_feature
|
||||
from ..core.validator_helpers import replace_none
|
||||
from .datasets_user_defined import GeneratorDataset
|
||||
|
||||
|
||||
class SamplingStrategy(IntEnum):
|
||||
|
@ -126,26 +131,29 @@ class GraphData:
|
|||
num_client=1, auto_shutdown=True):
|
||||
self._dataset_file = dataset_file
|
||||
self._working_mode = working_mode
|
||||
self.data_format = "mindrecord"
|
||||
if num_parallel_workers is None:
|
||||
num_parallel_workers = 1
|
||||
|
||||
def stop():
|
||||
self._graph_data.stop()
|
||||
|
||||
if working_mode in ['local', 'client']:
|
||||
self._graph_data = GraphDataClient(dataset_file, num_parallel_workers, working_mode, hostname, port)
|
||||
atexit.register(stop)
|
||||
self._graph_data = GraphDataClient(self.data_format, dataset_file, num_parallel_workers, working_mode,
|
||||
hostname, port)
|
||||
atexit.register(self.stop)
|
||||
|
||||
if working_mode == 'server':
|
||||
self._graph_data = GraphDataServer(
|
||||
dataset_file, num_parallel_workers, hostname, port, num_client, auto_shutdown)
|
||||
atexit.register(stop)
|
||||
self.data_format, dataset_file, num_parallel_workers, hostname, port, num_client, auto_shutdown)
|
||||
atexit.register(self.stop)
|
||||
try:
|
||||
while self._graph_data.is_stopped() is not True:
|
||||
time.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
raise Exception("Graph data server receives KeyboardInterrupt.")
|
||||
|
||||
def stop(self):
|
||||
"""Stop GraphDataClient or GraphDataServer."""
|
||||
self._graph_data.stop()
|
||||
|
||||
@check_gnn_get_all_nodes
|
||||
def get_all_nodes(self, node_type):
|
||||
"""
|
||||
|
@ -320,7 +328,7 @@ class GraphData:
|
|||
|
||||
Args:
|
||||
node_list (Union[list, numpy.ndarray]): The given list of nodes.
|
||||
neighbor_type (int): Specify the type of neighbor.
|
||||
neighbor_type (int): Specify the type of neighbor node.
|
||||
output_format (OutputFormat, optional): Output storage format (default=OutputFormat.NORMAL)
|
||||
It can be any of [OutputFormat.NORMAL, OutputFormat.COO, OutputFormat.CSR].
|
||||
|
||||
|
@ -392,7 +400,7 @@ class GraphData:
|
|||
if self._working_mode == 'server':
|
||||
raise Exception("This method is not supported when working mode is server.")
|
||||
return self._graph_data.get_sampled_neighbors(
|
||||
node_list, neighbor_nums, neighbor_types, DE_C_INTER_SAMPLING_STRATEGY[strategy]).as_array()
|
||||
node_list, neighbor_nums, neighbor_types, DE_C_INTER_SAMPLING_STRATEGY.get(strategy)).as_array()
|
||||
|
||||
@check_gnn_get_neg_sampled_neighbors
|
||||
def get_neg_sampled_neighbors(self, node_list, neg_neighbor_num, neg_neighbor_type):
|
||||
|
@ -521,3 +529,859 @@ class GraphData:
|
|||
raise Exception("This method is not supported when working mode is server.")
|
||||
return self._graph_data.random_walk(target_nodes, meta_path, step_home_param, step_away_param,
|
||||
default_node).as_array()
|
||||
|
||||
|
||||
class Graph(GraphData):
|
||||
"""
|
||||
A graph object for storing Graph structure and feature data.
|
||||
|
||||
This class supports initializing a graph with input numpy array data, which represents edges, nodes and their features.
|
||||
If working mode is `local`, there is no need to specify input arguments like `working_mode`, `hostname`, `port`,
|
||||
`num_client`, `auto_shutdown`.
|
||||
|
||||
Args:
|
||||
edges(Union[list, numpy.ndarray]): edges of graph in COO format with shape [2, num_edges].
|
||||
node_feat(dict, optional): feature of nodes, key is feature type, value should be numpy.array with shape
|
||||
[num_nodes, num_node_features], feature type should be string, like 'weight' etc.
|
||||
edge_feat(dict, optional): feature of edges, key is feature type, value should be numpy.array with shape
|
||||
[num_edges, num_edge_features], feature type should be string, like 'weight' etc.
|
||||
graph_feat(dict, optional): additional feature, which can not be assigned to node_feat or edge_feat, key is
|
||||
feature type, value should be numpy.array.
|
||||
node_type(Union[list, numpy.ndarray], optional): type of nodes, each element should be string which represent
|
||||
type of corresponding node. If not provided, default type for each node is '0'.
|
||||
edge_type(Union[list, numpy.ndarray], optional): type of edges, each element should be string which represent
|
||||
type of corresponding edge. If not provided, default type for each edge is '0'.
|
||||
num_parallel_workers (int, optional): Number of workers to process the dataset in parallel (default=None).
|
||||
working_mode (str, optional): Set working mode, now supports 'local'/'client'/'server' (default='local').
|
||||
|
||||
- 'local', used in non-distributed training scenarios.
|
||||
|
||||
- 'client', used in distributed training scenarios. The client does not load data,
|
||||
but obtains data from the server.
|
||||
|
||||
- 'server', used in distributed training scenarios. The server loads the data
|
||||
and is available to the client.
|
||||
|
||||
hostname (str, optional): Hostname of the graph data server. This parameter is only valid when
|
||||
working_mode is set to 'client' or 'server' (default='127.0.0.1').
|
||||
port (int, optional): Port of the graph data server. The range is 1024-65535. This parameter is
|
||||
only valid when working_mode is set to 'client' or 'server' (default=50051).
|
||||
num_client (int, optional): Maximum number of clients expected to connect to the server. The server will
|
||||
allocate resources according to this parameter. This parameter is only valid when working_mode
|
||||
is set to 'server' (default=1).
|
||||
auto_shutdown (bool, optional): Valid when working_mode is set to 'server',
|
||||
when the number of connected clients reaches num_client and no client is being connected,
|
||||
the server automatically exits (default=True).
|
||||
|
||||
Examples:
|
||||
>> # 1) Only provide edges for creating graph, as this is the only required input parameter
|
||||
>> edges = np.array([[1, 2], [0, 1]], dtype=np.int32)
|
||||
>> g = Graph(edges)
|
||||
>> graph_info = g.graph_info()
|
||||
>>
|
||||
>> # 2) Setting node_feat and edge_feat for corresponding node and edge
|
||||
>> # first dimension of feature shape should be corresponding node num or edge num.
|
||||
>> edges = np.array([[1, 2], [0, 1]], dtype=np.int32)
|
||||
>> node_feat = {"node_feature_1": np.array([[0], [1], [2]], dtype=np.int32)}
|
||||
>> edge_feat = {"edge_feature_1": np.array([[1, 2], [3, 4]], dtype=np.int32)}
|
||||
>> g = Graph(edges, node_feat, edge_feat)
|
||||
>>
|
||||
>> # 3) Setting graph feature for graph, there is shape limit for graph feature
|
||||
>> edges = np.array([[1, 2], [0, 1]], dtype=np.int32)
|
||||
>> graph_feature = {"graph_feature_1": np.array([1, 2, 3, 4, 5, 6], dtype=np.int32)}
|
||||
>> g = Graph(edges, graph_feat=graph_feature)
|
||||
"""
|
||||
|
||||
    @check_gnn_graph
    def __init__(self, edges, node_feat=None, edge_feat=None, graph_feat=None, node_type=None, edge_type=None,
                 num_parallel_workers=None, working_mode='local', hostname='127.0.0.1', port=50051, num_client=1,
                 auto_shutdown=True):
        # Normalize the optional feature dicts so later code can assume dicts.
        node_feat = replace_none(node_feat, {})
        edge_feat = replace_none(edge_feat, {})
        graph_feat = replace_none(graph_feat, {})
        edges = np.array(edges, dtype=np.int32)
        # infer num_nodes: default is the largest node id in `edges` plus one;
        # when node features are provided, the first feature's leading
        # dimension is taken as the node count instead (covers isolated nodes).
        num_nodes = np.max(edges) + 1
        if node_feat != dict():
            num_nodes = node_feat.get(list(node_feat.keys())[0]).shape[0]

        # Default node/edge type is the single string type '0'.
        node_type = replace_none(node_type, np.array(['0'] * num_nodes))
        node_type = np.array(node_type)
        edge_type = replace_none(edge_type, np.array(['0'] * edges.shape[1]))
        edge_type = np.array(edge_type)

        self._working_mode = working_mode
        # "array" tells the backend that the graph is built from in-memory
        # numpy data (GraphData uses "mindrecord" for file-based loading).
        self.data_format = "array"
        # string -> int mappings (and their inverses) for node/edge types and
        # node/edge/graph feature types; populated by _replace_string below.
        self.node_type_mapping, self.edge_type_mapping = dict(), dict()
        self.node_feature_type_mapping, self.edge_feature_type_mapping = dict(), dict()
        self.graph_feature_type_mapping = dict()
        self.invert_node_type_mapping, self.invert_edge_type_mapping = dict(), dict()
        self.invert_node_feature_type_mapping, self.invert_edge_feature_type_mapping = dict(), dict()
        self.invert_graph_feature_type_mapping = dict()

        # The backend works with int type ids only, so replace all string type
        # names with ints and remember the mappings for later translation.
        node_feat, edge_feat, graph_feat, node_type, edge_type = \
            self._replace_string(node_feat, edge_feat, graph_feat, node_type, edge_type)

        if num_parallel_workers is None:
            num_parallel_workers = 1

        if working_mode in ['local', 'client']:
            # GraphDataClient should support different init way, as data might be different
            self._graph_data = GraphDataClient(self.data_format, num_nodes, edges, node_feat, edge_feat, graph_feat,
                                               node_type, edge_type, num_parallel_workers, working_mode, hostname,
                                               port)
            atexit.register(self.stop)

        if working_mode == 'server':
            self._graph_data = GraphDataServer(self.data_format, num_nodes, edges, node_feat, edge_feat, graph_feat,
                                               node_type, edge_type, num_parallel_workers, hostname, port, num_client,
                                               auto_shutdown)
            atexit.register(self.stop)
            # Block this process until the server stops; Ctrl+C is re-raised
            # as a plain Exception so atexit handlers still run on the way out.
            try:
                while self._graph_data.is_stopped() is not True:
                    time.sleep(1)
            except KeyboardInterrupt:
                raise Exception("Graph data server receives KeyboardInterrupt.")
|
||||
    def stop(self):
        """Stop GraphDataClient or GraphDataServer."""
        # Registered with atexit in __init__, so the underlying client/server
        # is shut down automatically when the interpreter exits.
        self._graph_data.stop()
|
||||
    @check_gnn_get_all_nodes
    def get_all_nodes(self, node_type):
        """
        Get all nodes in the graph.

        Args:
            node_type (str): Specify the type of node.

        Returns:
            numpy.ndarray, array of nodes.

        Examples:
            >>> nodes = graph_dataset.get_all_nodes(node_type="0")

        Raises:
            TypeError: If `node_type` is not string.
            ValueError: If `node_type` does not exist in the graph.
        """
        if self._working_mode == 'server':
            raise Exception("This method is not supported when working mode is server.")

        # Translate the user-facing string node type into the internal int id.
        if node_type not in self.node_type_mapping:
            raise ValueError("Given node type {} is not exist in graph, existed is: {}."
                             .format(node_type, list(self.node_type_mapping.keys())))
        node_int_type = self.node_type_mapping[node_type]
        return self._graph_data.get_all_nodes(node_int_type).as_array()
|
||||
@check_gnn_get_all_edges
|
||||
def get_all_edges(self, edge_type):
|
||||
"""
|
||||
Get all edges in the graph.
|
||||
|
||||
Args:
|
||||
edge_type (int): Specify the type of edge.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray, array of edges.
|
||||
|
||||
Examples:
|
||||
>>> edges = graph_dataset.get_all_edges(edge_type='0')
|
||||
|
||||
Raises:
|
||||
TypeError: If `edge_type` is not string.
|
||||
"""
|
||||
if self._working_mode == 'server':
|
||||
raise Exception("This method is not supported when working mode is server.")
|
||||
|
||||
if edge_type not in self.edge_type_mapping:
|
||||
raise ValueError("Given node type {} is not exist in graph, existed is: {}."
|
||||
.format(edge_type, list(self.edge_type_mapping.keys())))
|
||||
edge_int_type = self.node_type_mapping[edge_type]
|
||||
return self._graph_data.get_all_edges(edge_int_type).as_array()
|
||||
|
||||
    @check_gnn_get_all_neighbors
    def get_all_neighbors(self, node_list, neighbor_type, output_format=OutputFormat.NORMAL):
        """
        Get `neighbor_type` neighbors of the nodes in `node_list`.
        We try to use the following example to illustrate the definition of these formats. 1 represents connected
        between two nodes, and 0 represents not connected.

        .. list-table:: Adjacent Matrix
            :widths: 20 20 20 20 20
            :header-rows: 1

            * -
              - 0
              - 1
              - 2
              - 3
            * - 0
              - 0
              - 1
              - 0
              - 0
            * - 1
              - 0
              - 0
              - 1
              - 0
            * - 2
              - 1
              - 0
              - 0
              - 1
            * - 3
              - 1
              - 0
              - 0
              - 0

        .. list-table:: Normal Format
            :widths: 20 20 20 20 20
            :header-rows: 1

            * - src
              - 0
              - 1
              - 2
              - 3
            * - dst_0
              - 1
              - 2
              - 0
              - 1
            * - dst_1
              - -1
              - -1
              - 3
              - -1

        .. list-table:: COO Format
            :widths: 20 20 20 20 20 20
            :header-rows: 1

            * - src
              - 0
              - 1
              - 2
              - 2
              - 3
            * - dst
              - 1
              - 2
              - 0
              - 3
              - 1

        .. list-table:: CSR Format
            :widths: 40 20 20 20 20 20
            :header-rows: 1

            * - offsetTable
              - 0
              - 1
              - 2
              - 4
              -
            * - dstTable
              - 1
              - 2
              - 0
              - 3
              - 1

        Args:
            node_list (Union[list, numpy.ndarray]): The given list of nodes.
            neighbor_type (str): Specify the type of neighbor node.
            output_format (OutputFormat, optional): Output storage format (default=OutputFormat.NORMAL)
                It can be any of [OutputFormat.NORMAL, OutputFormat.COO, OutputFormat.CSR].

        Returns:
            For NORMAL format or COO format
            numpy.ndarray which represents the array of neighbors will return.
            As if CSR format is specified, two numpy.ndarrays will return.
            The first one is offset table, the second one is neighbors

        Examples:
            >>> from mindspore.dataset.engine import OutputFormat
            >>> nodes = graph_dataset.get_all_nodes(node_type='0')
            >>> neighbors = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type='0')
            >>> neighbors_coo = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type='0',
            ...                                                 output_format=OutputFormat.COO)
            >>> offset_table, neighbors_csr = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type='0',
            ...                                                               output_format=OutputFormat.CSR)

        Raises:
            TypeError: If `node_list` is not list or ndarray.
            TypeError: If `neighbor_type` is not string.
            ValueError: If `neighbor_type` does not exist in the graph.
        """
        if self._working_mode == 'server':
            raise Exception("This method is not supported when working mode is server.")
        # Translate the user-facing string neighbor type into the internal int id.
        if neighbor_type not in self.node_type_mapping:
            raise ValueError("Given neighbor node type {} is not exist in graph, existed is: {}."
                             .format(neighbor_type, list(self.node_type_mapping.keys())))
        neighbor_int_type = self.node_type_mapping[neighbor_type]
        result_list = self._graph_data.get_all_neighbors(node_list, neighbor_int_type,
                                                         DE_C_INTER_OUTPUT_FORMAT[output_format]).as_array()
        # CSR results come back as one flat array: the first len(node_list)
        # entries are the offset table, the rest are the neighbor ids.
        if output_format == OutputFormat.CSR:
            offset_table = result_list[:len(node_list)]
            neighbor_table = result_list[len(node_list):]
            return offset_table, neighbor_table
        return result_list
|
||||
    @check_gnn_get_sampled_neighbors
    def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types, strategy=SamplingStrategy.RANDOM):
        """
        Get sampled neighbor information.

        The api supports multi-hop neighbor sampling. That is, the previous sampling result is used as the input of
        next-hop sampling. A maximum of 6-hop are allowed.

        The sampling result is tiled into a list in the format of [input node, 1-hop sampling result,
        2-hop sampling result ...]

        Args:
            node_list (Union[list, numpy.ndarray]): The given list of nodes.
            neighbor_nums (Union[list, numpy.ndarray]): Number of neighbors sampled per hop.
            neighbor_types (Union[list, numpy.ndarray]): Neighbor type sampled per hop, each element should be string.
            strategy (SamplingStrategy, optional): Sampling strategy (default=SamplingStrategy.RANDOM).
                It can be any of [SamplingStrategy.RANDOM, SamplingStrategy.EDGE_WEIGHT].

                - SamplingStrategy.RANDOM, random sampling with replacement.
                - SamplingStrategy.EDGE_WEIGHT, sampling with edge weight as probability.

        Returns:
            numpy.ndarray, array of neighbors.

        Examples:
            >>> nodes = graph_dataset.get_all_nodes(node_type='0')
            >>> neighbors = graph_dataset.get_sampled_neighbors(node_list=nodes, neighbor_nums=[2, 2],
            ...                                                 neighbor_types=['0', '0'])

        Raises:
            TypeError: If `node_list` is not list or ndarray.
            TypeError: If `neighbor_nums` is not list or ndarray.
            TypeError: If `neighbor_types` is not list or ndarray.
            ValueError: If an element of `neighbor_types` does not exist in the graph.
        """
        if not isinstance(strategy, SamplingStrategy):
            raise TypeError("Wrong input type for strategy, should be enum of 'SamplingStrategy'.")
        if self._working_mode == 'server':
            raise Exception("This method is not supported when working mode is server.")

        # Translate every per-hop string neighbor type into its internal int id.
        neighbor_int_types = []
        for neighbor_type in neighbor_types:
            if neighbor_type not in self.node_type_mapping:
                raise ValueError("Given neighbor node type {} is not exist in graph, existed is: {}."
                                 .format(neighbor_type, list(self.node_type_mapping.keys())))
            neighbor_int_types.append(self.node_type_mapping[neighbor_type])
        return self._graph_data.get_sampled_neighbors(
            node_list, neighbor_nums, neighbor_int_types, DE_C_INTER_SAMPLING_STRATEGY.get(strategy)).as_array()
|
||||
    @check_gnn_get_neg_sampled_neighbors
    def get_neg_sampled_neighbors(self, node_list, neg_neighbor_num, neg_neighbor_type):
        """
        Get `neg_neighbor_type` negative sampled neighbors of the nodes in `node_list`.

        Args:
            node_list (Union[list, numpy.ndarray]): The given list of nodes.
            neg_neighbor_num (int): Number of neighbors sampled.
            neg_neighbor_type (str): Specify the type of negative neighbor.

        Returns:
            numpy.ndarray, array of neighbors.

        Examples:
            >>> nodes = graph_dataset.get_all_nodes(node_type='0')
            >>> neg_neighbors = graph_dataset.get_neg_sampled_neighbors(node_list=nodes, neg_neighbor_num=5,
            ...                                                         neg_neighbor_type='0')

        Raises:
            TypeError: If `node_list` is not list or ndarray.
            TypeError: If `neg_neighbor_num` is not integer.
            TypeError: If `neg_neighbor_type` is not string.
            ValueError: If `neg_neighbor_type` does not exist in the graph.
        """
        if self._working_mode == 'server':
            raise Exception("This method is not supported when working mode is server.")
        # Translate the user-facing string neighbor type into the internal int id.
        if neg_neighbor_type not in self.node_type_mapping:
            raise ValueError("Given neighbor node type {} is not exist in graph, existed is: {}"
                             .format(neg_neighbor_type, list(self.node_type_mapping.keys())))
        neg_neighbor_int_type = self.node_type_mapping[neg_neighbor_type]
        return self._graph_data.get_neg_sampled_neighbors(
            node_list, neg_neighbor_num, neg_neighbor_int_type).as_array()
|
||||
    @check_gnn_get_node_feature
    def get_node_feature(self, node_list, feature_types):
        """
        Get `feature_types` feature of the nodes in `node_list`.

        Args:
            node_list (Union[list, numpy.ndarray]): The given list of nodes.
            feature_types (Union[list, numpy.ndarray]): The given list of feature types, each element should be string.

        Returns:
            list of numpy.ndarray, one feature array per entry of `feature_types`.

        Examples:
            >>> nodes = graph_dataset.get_all_nodes(node_type='0')
            >>> features = graph_dataset.get_node_feature(node_list=nodes, feature_types=["feature_1", "feature_2"])

        Raises:
            TypeError: If `node_list` is not list or ndarray.
            TypeError: If `feature_types` is not list or ndarray.
            ValueError: If an element of `feature_types` does not exist in the graph.
        """
        if self._working_mode == 'server':
            raise Exception("This method is not supported when working mode is server.")

        # Translate string feature type names into internal int feature ids.
        feature_int_types = []
        for feature_type in feature_types:
            if feature_type not in self.node_feature_type_mapping:
                raise ValueError("Given node feature type {} is not exist in graph, existed is: {}."
                                 .format(feature_type, list(self.node_feature_type_mapping.keys())))
            feature_int_types.append(self.node_feature_type_mapping[feature_type])
        # The backend expects a Tensor of int32 node ids.
        if isinstance(node_list, list):
            node_list = np.array(node_list, dtype=np.int32)
        return [
            t.as_array() for t in self._graph_data.get_node_feature(
                Tensor(node_list),
                feature_int_types)]
|
||||
    @check_gnn_get_edge_feature
    def get_edge_feature(self, edge_list, feature_types):
        """
        Get `feature_types` feature of the edges in `edge_list`.

        Args:
            edge_list (Union[list, numpy.ndarray]): The given list of edges.
            feature_types (Union[list, numpy.ndarray]): The given list of feature types, each element should be string.

        Returns:
            list of numpy.ndarray, one feature array per entry of `feature_types`.

        Examples:
            >>> edges = graph_dataset.get_all_edges(edge_type='0')
            >>> features = graph_dataset.get_edge_feature(edge_list=edges, feature_types=["feature_1"])

        Raises:
            TypeError: If `edge_list` is not list or ndarray.
            TypeError: If `feature_types` is not list or ndarray.
            ValueError: If an element of `feature_types` does not exist in the graph.
        """
        if self._working_mode == 'server':
            raise Exception("This method is not supported when working mode is server.")
        # Translate string feature type names into internal int feature ids.
        feature_int_types = []
        for feature_type in feature_types:
            if feature_type not in self.edge_feature_type_mapping:
                raise ValueError("Given edge feature type {} is not exist in graph, existed is: {}."
                                 .format(feature_type, list(self.edge_feature_type_mapping.keys())))
            feature_int_types.append(self.edge_feature_type_mapping[feature_type])

        # The backend expects a Tensor of int32 edge ids.
        if isinstance(edge_list, list):
            edge_list = np.array(edge_list, dtype=np.int32)
        return [
            t.as_array() for t in self._graph_data.get_edge_feature(
                Tensor(edge_list),
                feature_int_types)]
|
||||
@check_gnn_get_graph_feature
|
||||
def get_graph_feature(self, feature_types):
|
||||
"""
|
||||
Get `feature_types` feature of the nodes in `node_list`.
|
||||
|
||||
Args:
|
||||
feature_types (Union[list, numpy.ndarray]): The given list of feature types, each element should be string.
|
||||
|
||||
Returns:
|
||||
numpy.ndarray, array of features.
|
||||
|
||||
Examples:
|
||||
>>> features = graph_dataset.get_graph_feature(feature_types=['feature_1', 'feature_2'])
|
||||
|
||||
Raises:
|
||||
TypeError: If `node_list` is not list or ndarray.
|
||||
TypeError: If `feature_types` is not list or ndarray.
|
||||
"""
|
||||
if self._working_mode in ['server']:
|
||||
raise Exception("This method is not supported when working mode is server.")
|
||||
|
||||
feature_int_types = []
|
||||
for feature_type in feature_types:
|
||||
if feature_type not in self.graph_feature_type_mapping:
|
||||
raise ValueError("Given graph feature type {} is not exist in graph, existed is: {}."
|
||||
.format(feature_type, list(self.graph_feature_type_mapping.keys())))
|
||||
feature_int_types.append(self.graph_feature_type_mapping[feature_type])
|
||||
return [t.as_array() for t in self._graph_data.get_graph_feature(feature_int_types)]
|
||||
|
||||
@staticmethod
|
||||
def _convert_list(data, mapping):
|
||||
"""
|
||||
Convert list data according to given mapping.
|
||||
"""
|
||||
new_data = []
|
||||
for item in data:
|
||||
new_data.append(mapping[item])
|
||||
return new_data
|
||||
|
||||
@staticmethod
|
||||
def _convert_dict(data, mapping):
|
||||
"""
|
||||
Convert dict data according to given mapping.
|
||||
"""
|
||||
new_data = dict()
|
||||
for key, value in data.items():
|
||||
new_data[mapping[key]] = value
|
||||
return new_data
|
||||
|
||||
    def graph_info(self):
        """
        Get the meta information of the graph, including the number of nodes, the type of nodes,
        the feature information of nodes, the number of edges, the type of edges, and the feature information of edges.

        Returns:
            dict, meta information of the graph. The keys are node_type, edge_type, node_num, edge_num,
            node_feature_type, edge_feature_type and graph_feature_type.
        """
        if self._working_mode == 'server':
            raise Exception("This method is not supported when working mode is server.")
        # do type convert for node_type, edge_type, and other feature_type:
        # the backend reports internal int ids, so map everything back to the
        # user-facing string names via the inverted mappings built in __init__.
        raw_info = self._graph_data.graph_info()
        graph_info = dict()
        graph_info["node_type"] = self._convert_list(raw_info["node_type"], self.invert_node_type_mapping)
        graph_info["edge_type"] = self._convert_list(raw_info["edge_type"], self.invert_edge_type_mapping)
        graph_info["node_feature_type"] = \
            self._convert_list(raw_info["node_feature_type"], self.invert_node_feature_type_mapping)
        graph_info["edge_feature_type"] = \
            self._convert_list(raw_info["edge_feature_type"], self.invert_edge_feature_type_mapping)
        graph_info["graph_feature_type"] = \
            self._convert_list(raw_info["graph_feature_type"], self.invert_graph_feature_type_mapping)
        # node_num / edge_num are dicts keyed by type, so re-key them as well.
        graph_info["node_num"] = self._convert_dict(raw_info["node_num"], self.invert_node_type_mapping)
        graph_info["edge_num"] = self._convert_dict(raw_info["edge_num"], self.invert_edge_type_mapping)
        return graph_info
|
||||
    def _replace_string(self, node_feat, edge_feat, graph_feat, node_type, edge_type):
        """
        Replace key in node_feat, edge_feat and graph_feat from string into int, and replace value in node_type and
        edge_type from string to int.

        Side effect: stores the forward string->int mappings and their
        inverses on `self` for later translation of user queries and results.
        """

        def replace_dict_key(feature):
            # Re-key a feature dict with consecutive ints (insertion order),
            # returning the new dict and the string->int mapping.
            index = 0
            new_feature = dict()
            feature_type_mapping = dict()
            for item in feature.items():
                new_feature[index] = item[1]
                feature_type_mapping[item[0]] = index
                index += 1
            return new_feature, feature_type_mapping

        def replace_value(data_type):
            # Replace every distinct string type (in np.unique sorted order)
            # with a consecutive int, mutating the array in place, then cast
            # to int8.
            # NOTE(review): the in-place overwrite stores indices as strings
            # until the final astype; if a user type name happens to equal an
            # index string already written (e.g. types like '0x', '0y', '1'),
            # later matches could pick up replaced cells — confirm type names
            # never look like small integers, or rebuild into a fresh array.
            # NOTE(review): int8 also caps the number of distinct types at 128.
            index = 0
            feature_type_mapping = dict()
            node_type_set = np.unique(data_type)
            for item in node_type_set:
                data_type[data_type == item] = index
                feature_type_mapping[item] = index
                index += 1
            data_type = data_type.astype(np.int8)
            return data_type, feature_type_mapping

        def invert_dict(mapping):
            # Swap keys and values (mappings are bijective by construction).
            new_mapping = dict()
            for key, value in mapping.items():
                new_mapping[value] = key
            return new_mapping

        new_node_feat, self.node_feature_type_mapping = replace_dict_key(node_feat)
        new_edge_feat, self.edge_feature_type_mapping = replace_dict_key(edge_feat)
        new_graph_feat, self.graph_feature_type_mapping = replace_dict_key(graph_feat)
        new_node_type, self.node_type_mapping = replace_value(node_type)
        new_edge_type, self.edge_type_mapping = replace_value(edge_type)

        self.invert_node_type_mapping = invert_dict(self.node_type_mapping)
        self.invert_edge_type_mapping = invert_dict(self.edge_type_mapping)
        self.invert_node_feature_type_mapping = invert_dict(self.node_feature_type_mapping)
        self.invert_edge_feature_type_mapping = invert_dict(self.edge_feature_type_mapping)
        self.invert_graph_feature_type_mapping = invert_dict(self.graph_feature_type_mapping)

        return (new_node_feat, new_edge_feat, new_graph_feat, new_node_type, new_edge_type)
|
||||
|
||||
def save_graphs(path, graph_list, num_graphs_per_file=1, data_format="numpy"):
    """
    Save the graphs in `graph_list` under directory `path` as .npz bundles.

    Every `num_graphs_per_file` graphs are packed into one archive named
    "graph_<first_idx>_<last_idx>.npz"; inside the archive each array key is
    prefixed with "graph_<idx>_" so `load_graphs` can split graphs apart again.

    When init a graph, input parameter including: edges, node_feat, edge_feat, graph_feat, node_type, edge_type.
    If do collate function, data of graph will be load into python layer,
    but we consider to implement save graph in c++ layer, thus save to single graph_idx.npz firstly.
    """

    def merge_into_dict(data, data_array, feature_type, prefix):
        # Store each feature array under the key "<prefix><feature_type>".
        for key, value in zip(feature_type, data_array):
            # shape each item should be [num_xxx, num_feature]
            data[prefix + str(key)] = value

    graph_data = dict()
    pre_idx = 0
    graph_num = len(graph_list)
    for idx, graph in enumerate(graph_list):
        graph_info = graph.graph_info()
        # currently input args of get_all_edges can only be int not list.
        edge_ids = graph.get_all_edges(graph_info["edge_type"][0])
        # get_nodes_from_edges yields (src, dst) pairs; transpose to [2, num_edges] COO.
        edges = np.array(graph.get_nodes_from_edges(edge_ids)).transpose()
        graph_data["graph_" + str(idx) + "_edges"] = edges

        # currently input args of get_all_nodes can only be int not list.
        node_ids = graph.get_all_nodes(graph_info["node_type"][0])
        if graph_info["node_feature_type"]:
            node_feat = graph.get_node_feature(node_ids, graph_info["node_feature_type"])
            merge_into_dict(graph_data, node_feat, graph_info["node_feature_type"], "graph_" + str(idx) + "_node_feat_")
        if graph_info["edge_feature_type"]:
            edge_feat = graph.get_edge_feature(edge_ids, graph_info["edge_feature_type"])
            merge_into_dict(graph_data, edge_feat, graph_info["edge_feature_type"], "graph_" + str(idx) + "_edge_feat_")
        if graph_info["graph_feature_type"]:
            graph_feat = graph.get_graph_feature(graph_info["graph_feature_type"])
            merge_into_dict(graph_data, graph_feat, graph_info["graph_feature_type"],
                            "graph_" + str(idx) + "_graph_feat_")

        # node_type and edge_type need to provide access interface, current unable to get
        # Flush the accumulated graphs to disk once a bundle is full, or at the
        # last graph so a partial bundle is not lost.
        if (idx + 1) % num_graphs_per_file == 0 or idx == (graph_num - 1):
            file_name = "graph_" + str(pre_idx) + "_" + str(idx) + ".npz"
            file_path = os.path.join(path, file_name)
            np.savez(file_path, **graph_data)
            graph_data = dict()
            pre_idx = idx + 1
|
||||
|
||||
def load_graphs(path, data_format="numpy", num_parallel_workers=1):
    """
    Load graphs saved by `save_graphs` from directory `path` and rebuild them
    as `Graph` objects.

    To be implemented in c++ layer, logic may similar as current implementation.

    Args:
        path (str): Directory containing "graph_<start>_<end>.npz" bundles.
        data_format (str): Storage format, currently only "numpy" is produced
            by `save_graphs`.
        num_parallel_workers (int): Workers forwarded to each `Graph`.

    Returns:
        list of Graph.
    """
    # consider add param like in graph param: working_mode, num_client ...
    files = [os.path.join(path, file_name) for file_name in os.listdir(path)]
    # Sort in place so bundles are visited in deterministic order.
    # Bug fix: the previous `sorted(files)` discarded its result, leaving
    # `files` in os.listdir's arbitrary order.
    files.sort()

    def get_feature_data(param_name, cols, graph_data):
        # Collect the archive columns prefixed with "<param_name>_" into a
        # {feature_type: array} dict.
        data_dict = dict()
        param_name = param_name + "_"
        for col in cols:
            if param_name in col:
                feature_type = col.split(param_name)[1]
                # reshape data with 2 dimension
                temp_data = graph_data[col]
                if len(temp_data.shape) == 1 and "graph_feat_" not in param_name:
                    temp_data = temp_data.reshape(temp_data.shape[0], 1)
                data_dict[feature_type] = temp_data
        return data_dict

    graphs = []
    for file in files:
        if not file.endswith("npz"):
            continue
        data = np.load(file)
        # File name looks like "graph_<start>_<end>.npz"; recover the id range.
        # Bug fix: use basename/splitext instead of split("/") + strip(".npz") —
        # strip() removes any of the characters '.', 'n', 'p', 'z' from both
        # ends (corrupting names ending in those characters) and split("/")
        # breaks on Windows paths.
        id_list = os.path.splitext(os.path.basename(file))[0].split("_")
        ids = list(range(int(id_list[1]), int(id_list[2]) + 1))
        random.shuffle(ids)
        total_files = data.files
        for idx in ids:
            node_feat, edge_feat, graph_feat, node_type, edge_type = None, None, None, None, None
            keys = []
            prefix = "graph_" + str(idx) + "_"
            for item in total_files:
                if item.startswith(prefix):
                    keys.append(item)

            edges = data[prefix + "edges"]
            node_feat = get_feature_data(prefix + "node_feat", keys, data)
            edge_feat = get_feature_data(prefix + "edge_feat", keys, data)
            graph_feat = get_feature_data(prefix + "graph_feat", keys, data)

            # Bug fix: `keys` holds prefixed names ("graph_<idx>_node_type"),
            # so membership must be tested with the prefixed key — the bare
            # "node_type"/"edge_type" checks could never match.
            if prefix + "node_type" in keys:
                node_type = data[prefix + "node_type"]
            if prefix + "edge_type" in keys:
                edge_type = data[prefix + "edge_type"]

            # consider graph been created in graph mode firstly
            graph = Graph(edges, node_feat, edge_feat, graph_feat, node_type, edge_type,
                          num_parallel_workers=num_parallel_workers)
            graphs.append(graph)
    return graphs
|
||||
|
||||
class _UsersDatasetTemplate:
    """
    Template for guiding user to create corresponding dataset(should inherit InMemoryGraphDataset when implemented).

    InMemoryGraphDataset copies the user dataset's attributes and methods onto
    an instance of this class so GeneratorDataset receives a plain,
    random-accessible source object instead of a Dataset subclass.
    """

    def __init__(self):
        # Stub; real state is injected via setattr by InMemoryGraphDataset.
        pass

    def __getitem__(self, item):
        # Stub; replaced by the user-defined __getitem__.
        pass

    def __len__(self):
        # Stub length; replaced by the user-defined __len__.
        return 1

    def process(self):
        # Stub; replaced by the user-defined process.
        pass
|
||||
|
||||
class InMemoryGraphDataset(GeneratorDataset):
    """
    The basic Dataset for loading graph into memory.

    Recommended to inherit this class, and implement your own method like 'process', 'save' and 'load'.
    `process` is expected to fill `self.graphs`; its results are cached under
    `os.path.join(data_dir, save_dir)` by `save` and re-read by `load` on the
    next construction, so processing runs only once per dataset directory.
    """

    def __init__(self, data_dir, column_names="graph", save_dir="./processed", num_parallel_workers=1,
                 shuffle=None, num_shards=None, shard_id=None, python_multiprocessing=True, max_rowsize=6):
        # Graph objects produced by `process` / `load`.
        self.graphs = []
        self.data_dir = data_dir
        self.save_dir = save_dir
        self.processed_path = os.path.join(self.data_dir, self.save_dir)
        # Set True by _process when cached results already exist on disk.
        self.processed = False
        # Run the process/save cycle only when the subclass actually overrides
        # `process` (the base implementation would just raise).
        if 'process' in self.__class__.__dict__:
            self._process()
        if self.processed:
            self.load()

        # GeneratorDataset needs a plain random-accessible source, so copy
        # this dataset's state and the subclass's methods onto a lightweight
        # template instance instead of handing over `self` (a Dataset).
        source = _UsersDatasetTemplate()
        for k, v in self.__dict__.items():
            setattr(source, k, v)
        for k, v in self.__class__.__dict__.items():
            setattr(source.__class__, k, getattr(self.__class__, k))
        super().__init__(source, column_names=column_names, num_parallel_workers=num_parallel_workers, shuffle=shuffle,
                         num_shards=num_shards, shard_id=shard_id, python_multiprocessing=python_multiprocessing,
                         max_rowsize=max_rowsize)

    def process(self):
        """
        Override this method in your own dataset class: parse `data_dir` and
        fill `self.graphs`.
        """
        raise NotImplementedError("'process' method should be implemented in your own logic.")

    def save(self):
        """
        Override this method in your own dataset class. By default the graphs
        built by `process` are written to `processed_path` as .npz bundles.
        """
        save_graphs(self.processed_path, self.graphs)

    def load(self):
        """
        Override this method in your own dataset class. By default graphs are
        re-read from the .npz bundles in `processed_path`.
        """
        self.graphs = load_graphs(self.processed_path, num_parallel_workers=1)

    def _process(self):
        # file has been processed and saved into processed_path
        if not os.path.isdir(self.processed_path):
            os.makedirs(self.processed_path, exist_ok=True)
        elif os.listdir(self.processed_path):
            # Non-empty cache directory: reuse it instead of re-processing.
            self.processed = True
            return
        self.process()
        self.save()
|
||||
|
||||
class ArgoverseDataset(InMemoryGraphDataset):
|
||||
"""
|
||||
Load argoverse dataset and create graph.
|
||||
"""
|
||||
|
||||
def __init__(self, data_dir, column_names="graph", shuffle=None, num_parallel_workers=1,
|
||||
python_multiprocessing=True, perf_mode=True):
|
||||
# For high performance, here we store edge_index into graph_feature directly
|
||||
self.perf_mode = perf_mode
|
||||
super().__init__(data_dir, column_names, shuffle=shuffle, num_parallel_workers=num_parallel_workers,
|
||||
python_multiprocessing=python_multiprocessing)
|
||||
|
||||
def __getitem__(self, index):
|
||||
graph = self.graphs[index]
|
||||
if self.perf_mode:
|
||||
return graph.get_graph_feature(
|
||||
feature_types=["edge_index", "x", "y", "cluster", "valid_len", "time_step_len"])
|
||||
|
||||
graph_info = graph.graph_info()
|
||||
all_nodes = graph.get_all_nodes(graph_info["node_type"][0])
|
||||
edge_ids = graph.get_all_edges(graph_info["edge_type"][0])
|
||||
edge_index = np.array(graph.get_nodes_from_edges(edge_ids)).transpose()
|
||||
x = graph.get_node_feature(all_nodes, feature_types=["x"])[0]
|
||||
graph_feature = graph.get_graph_feature(feature_types=["y", "cluster", "valid_len", "time_step_len"])
|
||||
y, cluster, valid_len, time_step_len = graph_feature
|
||||
|
||||
return edge_index, x, y, cluster, valid_len, time_step_len
|
||||
|
||||
def __len__(self):
|
||||
return len(self.graphs)
|
||||
|
||||
def process(self):
|
||||
"""
|
||||
process method mainly refers to: https://github.com/xk-huang/yet-another-vectornet/blob/master/dataset.py
|
||||
"""
|
||||
|
||||
def get_edge_full_connection(node_num, start_index=0):
|
||||
"""
|
||||
Obtain edge_index with shape (2, edge_num)
|
||||
"""
|
||||
edges = np.empty((2, 0))
|
||||
end = np.arange(node_num, dtype=np.int64)
|
||||
for idx in range(node_num):
|
||||
begin = np.ones(node_num, dtype=np.int64) * idx
|
||||
edges = np.hstack((edges, np.vstack(
|
||||
(np.hstack([begin[:idx], begin[idx + 1:]]), np.hstack([end[:idx], end[idx + 1:]])))))
|
||||
edges = edges + start_index
|
||||
|
||||
return edges.astype(np.int64), node_num + start_index
|
||||
|
||||
file_path = [os.path.join(self.data_dir, file_name) for file_name in os.listdir(self.data_dir)]
|
||||
sorted(file_path)
|
||||
|
||||
valid_len_list = []
|
||||
data_list = []
|
||||
for data_p in file_path:
|
||||
if not data_p.endswith('pkl'):
|
||||
continue
|
||||
x_list, edge_index_list = [], []
|
||||
data = pd.read_pickle(data_p)
|
||||
input_features = data['POLYLINE_FEATURES'].values[0]
|
||||
basic_len = data['TARJ_LEN'].values[0]
|
||||
cluster = input_features[:, -1].reshape(-1).astype(np.int32)
|
||||
valid_len_list.append(cluster.max())
|
||||
y = data['GT'].values[0].reshape(-1).astype(np.float32)
|
||||
|
||||
traj_id_mask = data["TRAJ_ID_TO_MASK"].values[0]
|
||||
lane_id_mask = data['LANE_ID_TO_MASK'].values[0]
|
||||
start_idx = 0
|
||||
|
||||
for _, mask in traj_id_mask.items():
|
||||
feature = input_features[mask[0]:mask[1]]
|
||||
temp_edge, start_idx = get_edge_full_connection(
|
||||
feature.shape[0], start_idx)
|
||||
x_list.append(feature)
|
||||
edge_index_list.append(temp_edge)
|
||||
|
||||
for _, mask in lane_id_mask.items():
|
||||
feature = input_features[mask[0] + basic_len: mask[1] + basic_len]
|
||||
temp_edge, start_idx = get_edge_full_connection(
|
||||
feature.shape[0], start_idx)
|
||||
x_list.append(feature)
|
||||
edge_index_list.append(temp_edge)
|
||||
edge_index = np.hstack(edge_index_list)
|
||||
x = np.vstack(x_list)
|
||||
data_list.append([x, y, cluster, edge_index])
|
||||
|
||||
graphs = []
|
||||
pad_to_index = np.max(valid_len_list)
|
||||
feature_len = data_list[0][0].shape[1]
|
||||
for index, item in enumerate(data_list):
|
||||
item[0] = np.vstack(
|
||||
[item[0], np.zeros((pad_to_index - item[-2].max(), feature_len), dtype=item[0].dtype)])
|
||||
item[-2] = np.hstack(
|
||||
[item[2], np.arange(item[-2].max() + 1, pad_to_index + 1)])
|
||||
|
||||
if self.perf_mode:
|
||||
graph_feature = {"edge_index": item[3], "x": item[0], "y": item[1], "cluster": item[2],
|
||||
"valid_len": np.array([valid_len_list[index]]),
|
||||
"time_step_len": np.array([pad_to_index + 1])}
|
||||
g_data = Graph(edges=item[3], graph_feat=graph_feature)
|
||||
else:
|
||||
node_feature = {"x": item[0]}
|
||||
graph_feature = {"y": item[1], "cluster": item[2], "valid_len": np.array([valid_len_list[index]]),
|
||||
"time_step_len": np.array([pad_to_index + 1])}
|
||||
g_data = Graph(edges=item[3], node_feat=node_feature, graph_feat=graph_feature)
|
||||
graphs.append(g_data)
|
||||
self.graphs = graphs
|
||||
|
|
|
@ -28,7 +28,7 @@ from ..core.validator_helpers import parse_user_args, type_check, type_check_lis
|
|||
INT32_MAX, check_valid_detype, check_dir, check_file, check_sampler_shuffle_shard_options, \
|
||||
validate_dataset_param_value, check_padding_options, check_gnn_list_or_ndarray, check_gnn_list_of_pair_or_ndarray, \
|
||||
check_num_parallel_workers, check_columns, check_pos_int32, check_valid_str, check_dataset_num_shards_shard_id, \
|
||||
check_valid_list_tuple
|
||||
check_valid_list_tuple, check_dict
|
||||
|
||||
from . import datasets
|
||||
from . import samplers
|
||||
|
@ -1838,7 +1838,7 @@ def check_gnn_graphdata(method):
|
|||
check_num_parallel_workers(num_parallel_workers)
|
||||
type_check(hostname, (str,), "hostname")
|
||||
if check_hostname(hostname) is False:
|
||||
raise ValueError("The hostname is illegal")
|
||||
raise ValueError("The hostname is illegal.")
|
||||
type_check(working_mode, (str,), "working_mode")
|
||||
if working_mode not in {'local', 'client', 'server'}:
|
||||
raise ValueError("Invalid working mode, please enter 'local', 'client' or 'server'.")
|
||||
|
@ -1852,13 +1852,65 @@ def check_gnn_graphdata(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_gnn_graph(method):
|
||||
"""check the input arguments of Graph."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[edges, node_feat, edge_feat, graph_feat, node_type, edge_type, num_parallel_workers, working_mode,
|
||||
hostname, port, num_client, auto_shutdown], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
type_check(edges, (list, np.ndarray), "edges")
|
||||
check_dict(node_feat, str, np.ndarray, "node_feat")
|
||||
check_dict(edge_feat, str, np.ndarray, "edge_feat")
|
||||
check_dict(graph_feat, str, np.ndarray, "graph_feat")
|
||||
if node_type:
|
||||
type_check(node_type, (list, np.ndarray), "node_type")
|
||||
if edge_type:
|
||||
type_check(edge_type, (None, list, np.ndarray), "edge_type")
|
||||
|
||||
# check shape of node_feat and edge_feat
|
||||
num_nodes = np.max(edges) + 1
|
||||
if node_feat and isinstance(node_feat, dict):
|
||||
num_nodes = node_feat[list(node_feat.keys())[0]].shape[0]
|
||||
if node_feat:
|
||||
for key, value in node_feat.items():
|
||||
if len(value.shape) != 2 or value.shape[0] != num_nodes:
|
||||
raise ValueError("value of item '{0}' in node_feat should with shape [num_nodes, num_node_features]"
|
||||
"(here num_nodes is: {1}), but got: {2}".format(key, num_nodes, value.shape))
|
||||
if edge_feat:
|
||||
for key, value in edge_feat.items():
|
||||
if len(value.shape) != 2 or value.shape[0] != edges.shape[1]:
|
||||
raise ValueError("value of item '{0}' in edge_feat should with shape [num_edges, num_node_features]"
|
||||
"(here num_edges is: {1}), but got: {2}".format(key, edges.shape[1], value.shape))
|
||||
|
||||
if num_parallel_workers is not None:
|
||||
check_num_parallel_workers(num_parallel_workers)
|
||||
type_check(hostname, (str,), "hostname")
|
||||
if check_hostname(hostname) is False:
|
||||
raise ValueError("The hostname is illegal.")
|
||||
type_check(working_mode, (str,), "working_mode")
|
||||
if working_mode not in {'local', 'client', 'server'}:
|
||||
raise ValueError("Invalid working mode, please enter 'local', 'client' or 'server'.")
|
||||
type_check(port, (int,), "port")
|
||||
check_value(port, (1024, 65535), "port")
|
||||
type_check(num_client, (int,), "num_client")
|
||||
check_value(num_client, (1, 255), "num_client")
|
||||
type_check(auto_shutdown, (bool,), "auto_shutdown")
|
||||
return method(self, *args, **kwargs)
|
||||
return new_method
|
||||
|
||||
|
||||
def check_gnn_get_all_nodes(method):
|
||||
"""A wrapper that wraps a parameter checker around the GNN `get_all_nodes` function."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[node_type], _ = parse_user_args(method, *args, **kwargs)
|
||||
type_check(node_type, (int,), "node_type")
|
||||
if "GraphData" in str(type(self)):
|
||||
type_check(node_type, (int,), "node_type")
|
||||
else:
|
||||
type_check(node_type, (str,), "node_type")
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
|
@ -1871,7 +1923,10 @@ def check_gnn_get_all_edges(method):
|
|||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[edge_type], _ = parse_user_args(method, *args, **kwargs)
|
||||
type_check(edge_type, (int,), "edge_type")
|
||||
if "GraphData" in str(type(self)):
|
||||
type_check(edge_type, (int,), "edge_type")
|
||||
else:
|
||||
type_check(edge_type, (str,), "edge_type")
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
|
@ -1912,7 +1967,10 @@ def check_gnn_get_all_neighbors(method):
|
|||
[node_list, neighbour_type, _], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
check_gnn_list_or_ndarray(node_list, 'node_list')
|
||||
type_check(neighbour_type, (int,), "neighbour_type")
|
||||
if "GraphData" in str(type(self)):
|
||||
type_check(neighbour_type, (int,), "neighbour_type")
|
||||
else:
|
||||
type_check(neighbour_type, (str,), "neighbour_type")
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
|
@ -1933,7 +1991,10 @@ def check_gnn_get_sampled_neighbors(method):
|
|||
raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}.".format(
|
||||
'neighbor_nums', len(neighbor_nums)))
|
||||
|
||||
check_gnn_list_or_ndarray(neighbor_types, 'neighbor_types')
|
||||
if "GraphData" in str(type(self)):
|
||||
check_gnn_list_or_ndarray(neighbor_types, 'neighbor_types')
|
||||
else:
|
||||
check_gnn_list_or_ndarray(neighbor_types, 'neighbor_types', str)
|
||||
if not neighbor_types or len(neighbor_types) > 6:
|
||||
raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}.".format(
|
||||
'neighbor_types', len(neighbor_types)))
|
||||
|
@ -1956,7 +2017,11 @@ def check_gnn_get_neg_sampled_neighbors(method):
|
|||
|
||||
check_gnn_list_or_ndarray(node_list, 'node_list')
|
||||
type_check(neg_neighbor_num, (int,), "neg_neighbor_num")
|
||||
type_check(neg_neighbor_type, (int,), "neg_neighbor_type")
|
||||
|
||||
if "GraphData" in str(type(self)):
|
||||
type_check(neg_neighbor_type, (int,), "neg_neighbor_type")
|
||||
else:
|
||||
type_check(neg_neighbor_type, (str,), "neg_neighbor_type")
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
|
@ -2026,7 +2091,23 @@ def check_gnn_get_node_feature(method):
|
|||
raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
|
||||
node_list, node_list.dtype))
|
||||
|
||||
check_gnn_list_or_ndarray(feature_types, 'feature_types')
|
||||
if "GraphData" in str(type(self)):
|
||||
check_gnn_list_or_ndarray(feature_types, 'feature_types')
|
||||
else:
|
||||
check_gnn_list_or_ndarray(feature_types, 'feature_types', data_type=str)
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_gnn_get_graph_feature(method):
|
||||
"""A wrapper that wraps a parameter checker around the GNN `get_graph_feature` function."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[feature_types], _ = parse_user_args(method, *args, **kwargs)
|
||||
check_gnn_list_or_ndarray(feature_types, 'feature_types', str)
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
|
@ -2048,7 +2129,10 @@ def check_gnn_get_edge_feature(method):
|
|||
raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
|
||||
edge_list, edge_list.dtype))
|
||||
|
||||
check_gnn_list_or_ndarray(feature_types, 'feature_types')
|
||||
if "GraphData" in str(type(self)):
|
||||
check_gnn_list_or_ndarray(feature_types, 'feature_types')
|
||||
else:
|
||||
check_gnn_list_or_ndarray(feature_types, 'feature_types', data_type=str)
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
|
|
|
@ -100,7 +100,7 @@ class MindDataTestGNNGraph : public UT::Common {
|
|||
/// Expectation: Output is equal to the expected output
|
||||
TEST_F(MindDataTestGNNGraph, TestGetEdgesFromNodes) {
|
||||
std::string path = "data/mindrecord/testGraphData/testdata";
|
||||
GraphDataImpl graph(path, 1);
|
||||
GraphDataImpl graph("mindrecord", path, 1);
|
||||
Status s = graph.Init();
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
|
||||
|
@ -118,7 +118,7 @@ TEST_F(MindDataTestGNNGraph, TestGetEdgesFromNodes) {
|
|||
/// Expectation: Output is equal to the expected output
|
||||
TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) {
|
||||
std::string path = "data/mindrecord/testGraphData/testdata";
|
||||
GraphDataImpl graph(path, 1);
|
||||
GraphDataImpl graph("mindrecord", path, 1);
|
||||
Status s = graph.Init();
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
|
||||
|
@ -162,7 +162,7 @@ TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) {
|
|||
/// Expectation: Output is equal to the expected output
|
||||
TEST_F(MindDataTestGNNGraph, TestGetAllNeighborsSpecialFormat) {
|
||||
std::string path = "data/mindrecord/testGraphData/testdata";
|
||||
GraphDataImpl graph(path, 1);
|
||||
GraphDataImpl graph("mindrecord", path, 1);
|
||||
Status s = graph.Init();
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
|
||||
|
@ -206,7 +206,7 @@ TEST_F(MindDataTestGNNGraph, TestGetAllNeighborsSpecialFormat) {
|
|||
/// Expectation: Output is equal to the expected output
|
||||
TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) {
|
||||
std::string path = "data/mindrecord/testGraphData/testdata";
|
||||
GraphDataImpl graph(path, 1);
|
||||
GraphDataImpl graph("mindrecord", path, 1);
|
||||
Status s = graph.Init();
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
|
||||
|
@ -331,7 +331,7 @@ TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) {
|
|||
/// Expectation: Output is equal to the expected output
|
||||
TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) {
|
||||
std::string path = "data/mindrecord/testGraphData/testdata";
|
||||
GraphDataImpl graph(path, 1);
|
||||
GraphDataImpl graph("mindrecord", path, 1);
|
||||
Status s = graph.Init();
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
|
||||
|
@ -377,7 +377,7 @@ TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) {
|
|||
/// Expectation: Output is equal to the expected output
|
||||
TEST_F(MindDataTestGNNGraph, TestRandomWalk) {
|
||||
std::string path = "data/mindrecord/testGraphData/sns";
|
||||
GraphDataImpl graph(path, 1);
|
||||
GraphDataImpl graph("mindrecord", path, 1);
|
||||
Status s = graph.Init();
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
|
||||
|
@ -406,7 +406,7 @@ TEST_F(MindDataTestGNNGraph, TestRandomWalk) {
|
|||
/// Expectation: Output is equal to the expected output
|
||||
TEST_F(MindDataTestGNNGraph, TestRandomWalkDefaults) {
|
||||
std::string path = "data/mindrecord/testGraphData/sns";
|
||||
GraphDataImpl graph(path, 1);
|
||||
GraphDataImpl graph("mindrecord", path, 1);
|
||||
Status s = graph.Init();
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
|
||||
|
|
Binary file not shown.
|
@ -0,0 +1,120 @@
|
|||
import os
|
||||
import random
|
||||
import time
|
||||
from multiprocessing import Process
|
||||
import numpy as np
|
||||
from mindspore import log as logger
|
||||
from mindspore.dataset import Graph
|
||||
from mindspore.dataset import ArgoverseDataset
|
||||
|
||||
|
||||
def test_create_graph_with_edges():
|
||||
"""
|
||||
Feature: Graph
|
||||
Description: test create Graph with loading edge, node_feature, edge_feature
|
||||
Expectation: Output value equals to the expected
|
||||
"""
|
||||
edges = np.array([[1, 2], [0, 1]], dtype=np.int32)
|
||||
node_feat = {"label": np.array([[0], [1], [2]], dtype=np.int32)}
|
||||
edge_feat = {"feature": np.array([[1, 2], [3, 4]], dtype=np.int32)}
|
||||
g = Graph(edges, node_feat, edge_feat)
|
||||
|
||||
graph_info = g.graph_info()
|
||||
assert graph_info['node_type'] == ['0']
|
||||
assert graph_info['edge_type'] == ['0']
|
||||
assert graph_info['node_num'] == {'0': 3}
|
||||
assert graph_info['edge_num'] == {'0': 2}
|
||||
assert graph_info['node_feature_type'] == ['label']
|
||||
assert graph_info['edge_feature_type'] == ['feature']
|
||||
|
||||
all_nodes = g.get_all_nodes('0')
|
||||
assert all_nodes.tolist() == [0, 1, 2]
|
||||
all_edges = g.get_all_edges('0')
|
||||
assert all_edges.tolist() == [0, 1]
|
||||
node_feature = g.get_node_feature([0, 1], ["label"])
|
||||
assert node_feature[0].tolist() == [0, 1]
|
||||
edge_feature = g.get_edge_feature([0], ["feature"])
|
||||
assert edge_feature[0].tolist() == [1, 2]
|
||||
|
||||
|
||||
def start_graph_server_with_array(server_port):
|
||||
"""
|
||||
start graph server.
|
||||
"""
|
||||
edges = np.array([[1, 2], [0, 1]], dtype=np.int32)
|
||||
node_feat = {"label": np.array([[0], [1], [2]], dtype=np.int32)}
|
||||
edge_feat = {"feature": np.array([[1, 2], [3, 4]], dtype=np.int32)}
|
||||
graph_feat = {"feature_1": np.array([1, 2, 3, 4, 5], dtype=np.int32),
|
||||
"feature_2": np.array([11, 12, 13, 14, 15], dtype=np.int32)}
|
||||
Graph(edges, node_feat, edge_feat, graph_feat, working_mode='server', port=server_port)
|
||||
|
||||
|
||||
def test_server_mode_with_array():
|
||||
"""
|
||||
Feature: Graph
|
||||
Description: Test Graph distributed
|
||||
Expectation: Output equals to the expected output
|
||||
"""
|
||||
asan = os.environ.get('ASAN_OPTIONS')
|
||||
if asan:
|
||||
logger.info("skip the Graph distributed when asan mode")
|
||||
return
|
||||
|
||||
server_port = random.randint(10000, 60000)
|
||||
p1 = Process(target=start_graph_server_with_array, args=(server_port,))
|
||||
p1.start()
|
||||
time.sleep(5)
|
||||
|
||||
edges = np.array([[1, 2], [0, 1]], dtype=np.int32)
|
||||
node_feat = {"label": np.array([[0], [1], [2]], dtype=np.int32)}
|
||||
edge_feat = {"feature": np.array([[1, 2], [3, 4]], dtype=np.int32)}
|
||||
graph_feat = {"feature_1": np.array([1, 2, 3, 4, 5], dtype=np.int32),
|
||||
"feature_2": np.array([11, 12, 13, 14, 15], dtype=np.int32)}
|
||||
g = Graph(edges, node_feat, edge_feat, graph_feat, working_mode='client', port=server_port)
|
||||
|
||||
all_nodes = g.get_all_nodes('0')
|
||||
assert all_nodes.tolist() == [0, 1, 2]
|
||||
all_edges = g.get_all_edges('0')
|
||||
assert all_edges.tolist() == [0, 1]
|
||||
node_feature = g.get_node_feature([0, 1], ["label"])
|
||||
assert node_feature[0].tolist() == [0, 1]
|
||||
edge_feature = g.get_edge_feature([0], ["feature"])
|
||||
assert edge_feature[0].tolist() == [1, 2]
|
||||
graph_feature = g.get_graph_feature(["feature_1"])
|
||||
assert graph_feature[0].tolist() == [1, 2, 3, 4, 5]
|
||||
|
||||
|
||||
def test_graph_feature_local():
|
||||
"""
|
||||
Feature: Graph
|
||||
Description: Test load Graph feature in local mode
|
||||
Expectation: Output equals to the expected output
|
||||
"""
|
||||
edges = np.array([[1, 2], [0, 1]], dtype=np.int32)
|
||||
graph_feat = {"feature_1": np.array([1, 2, 3, 4, 5], dtype=np.int32),
|
||||
"feature_2": np.array([11, 12, 13, 14, 15], dtype=np.int32)}
|
||||
|
||||
g = Graph(edges, graph_feat=graph_feat)
|
||||
graph_feature = g.get_graph_feature(["feature_1"])
|
||||
assert graph_feature[0].tolist() == [1, 2, 3, 4, 5]
|
||||
|
||||
|
||||
def test_argoverse_dataset():
|
||||
"""
|
||||
Feature: Graph
|
||||
Description: Test self-implemented dataset which inherit InMemoryGraphDataset
|
||||
Expectation: Output equals to the expected output
|
||||
"""
|
||||
data_dir = "../data/dataset/testArgoverse"
|
||||
graph_dataset = ArgoverseDataset(data_dir,
|
||||
column_names=["edge_index", "x", "y", "cluster", "valid_len", "time_step_len"])
|
||||
for item in graph_dataset.create_dict_iterator(output_numpy=True, num_epochs=1):
|
||||
keys = list(item.keys())
|
||||
assert keys == ["edge_index", "x", "y", "cluster", "valid_len", "time_step_len"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_create_graph_with_edges()
|
||||
test_graph_feature_local()
|
||||
test_server_mode_with_array()
|
||||
test_argoverse_dataset()
|
Loading…
Reference in New Issue