!4498 Gnn data processing supports distributed scenarios

Merge pull request !4498 from heleiwang/gnn_distributed
This commit is contained in:
mindspore-ci-bot 2020-08-19 16:09:05 +08:00 committed by Gitee
commit 256dccc651
48 changed files with 3202 additions and 340 deletions

View File

@ -15,7 +15,14 @@ include(${CMAKE_SOURCE_DIR}/cmake/external_libs/json.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/dependency_securec.cmake) include(${CMAKE_SOURCE_DIR}/cmake/dependency_securec.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/protobuf.cmake) include(${CMAKE_SOURCE_DIR}/cmake/external_libs/protobuf.cmake)
SET(MS_BUILD_GRPC 0)
if (ENABLE_DEBUGGER OR ENABLE_SERVING OR ENABLE_TESTCASES) if (ENABLE_DEBUGGER OR ENABLE_SERVING OR ENABLE_TESTCASES)
SET(MS_BUILD_GRPC 1)
endif()
# MindData also requires gRPC, except on Windows: request the gRPC build in that case too.
if (ENABLE_MINDDATA AND NOT CMAKE_SYSTEM_NAME MATCHES "Windows")
SET(MS_BUILD_GRPC 1)
endif()
if ("${MS_BUILD_GRPC}")
# build dependencies of gRPC # build dependencies of gRPC
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/absl.cmake) include(${CMAKE_SOURCE_DIR}/cmake/external_libs/absl.cmake)
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/c-ares.cmake) include(${CMAKE_SOURCE_DIR}/cmake/external_libs/c-ares.cmake)

View File

@ -83,6 +83,7 @@ endif()
if (ENABLE_TDTQUE) if (ENABLE_TDTQUE)
add_dependencies(engine-tdt core) add_dependencies(engine-tdt core)
endif () endif ()
################### Create _c_dataengine Library ###################### ################### Create _c_dataengine Library ######################
set(submodules set(submodules
$<TARGET_OBJECTS:core> $<TARGET_OBJECTS:core>
@ -182,3 +183,7 @@ else()
set_target_properties(_c_dataengine PROPERTIES MACOSX_RPATH ON) set_target_properties(_c_dataengine PROPERTIES MACOSX_RPATH ON)
endif () endif ()
endif() endif()
# Link the gRPC C++ library into _c_dataengine on every platform except Windows.
if (NOT CMAKE_SYSTEM_NAME MATCHES "Windows")
target_link_libraries(_c_dataengine PRIVATE mindspore::grpc++)
endif()

View File

@ -18,83 +18,103 @@
#include "pybind11/stl_bind.h" #include "pybind11/stl_bind.h"
#include "minddata/dataset/api/python/pybind_register.h" #include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/engine/gnn/graph_data_client.h"
#include "minddata/dataset/engine/gnn/graph.h" #include "minddata/dataset/engine/gnn/graph_data_impl.h"
#include "minddata/dataset/engine/gnn/graph_data_server.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
PYBIND_REGISTER( PYBIND_REGISTER(
Graph, 0, ([](const py::module *m) { Graph, 0, ([](const py::module *m) {
(void)py::class_<gnn::Graph, std::shared_ptr<gnn::Graph>>(*m, "Graph") (void)py::class_<gnn::GraphData, std::shared_ptr<gnn::GraphData>>(*m, "GraphDataClient")
.def(py::init([](std::string dataset_file, int32_t num_workers) { .def(py::init([](const std::string &dataset_file, int32_t num_workers, const std::string &working_mode,
std::shared_ptr<gnn::Graph> g_out = std::make_shared<gnn::Graph>(dataset_file, num_workers); const std::string &hostname, int32_t port) {
THROW_IF_ERROR(g_out->Init()); std::shared_ptr<gnn::GraphData> out;
return g_out; if (working_mode == "local") {
out = std::make_shared<gnn::GraphDataImpl>(dataset_file, num_workers);
} else if (working_mode == "client") {
out = std::make_shared<gnn::GraphDataClient>(dataset_file, hostname, port);
}
THROW_IF_ERROR(out->Init());
return out;
})) }))
.def("get_all_nodes", .def("get_all_nodes",
[](gnn::Graph &g, gnn::NodeType node_type) { [](gnn::GraphData &g, gnn::NodeType node_type) {
std::shared_ptr<Tensor> out; std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetAllNodes(node_type, &out)); THROW_IF_ERROR(g.GetAllNodes(node_type, &out));
return out; return out;
}) })
.def("get_all_edges", .def("get_all_edges",
[](gnn::Graph &g, gnn::EdgeType edge_type) { [](gnn::GraphData &g, gnn::EdgeType edge_type) {
std::shared_ptr<Tensor> out; std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetAllEdges(edge_type, &out)); THROW_IF_ERROR(g.GetAllEdges(edge_type, &out));
return out; return out;
}) })
.def("get_nodes_from_edges", .def("get_nodes_from_edges",
[](gnn::Graph &g, std::vector<gnn::NodeIdType> edge_list) { [](gnn::GraphData &g, std::vector<gnn::NodeIdType> edge_list) {
std::shared_ptr<Tensor> out; std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out)); THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out));
return out; return out;
}) })
.def("get_all_neighbors", .def("get_all_neighbors",
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeType neighbor_type) { [](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeType neighbor_type) {
std::shared_ptr<Tensor> out; std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out)); THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out));
return out; return out;
}) })
.def("get_sampled_neighbors", .def("get_sampled_neighbors",
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeIdType> neighbor_nums, [](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeIdType> neighbor_nums,
std::vector<gnn::NodeType> neighbor_types) { std::vector<gnn::NodeType> neighbor_types) {
std::shared_ptr<Tensor> out; std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &out)); THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &out));
return out; return out;
}) })
.def("get_neg_sampled_neighbors", .def("get_neg_sampled_neighbors",
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeIdType neighbor_num, [](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeIdType neighbor_num,
gnn::NodeType neg_neighbor_type) { gnn::NodeType neg_neighbor_type) {
std::shared_ptr<Tensor> out; std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetNegSampledNeighbors(node_list, neighbor_num, neg_neighbor_type, &out)); THROW_IF_ERROR(g.GetNegSampledNeighbors(node_list, neighbor_num, neg_neighbor_type, &out));
return out; return out;
}) })
.def("get_node_feature", .def("get_node_feature",
[](gnn::Graph &g, std::shared_ptr<Tensor> node_list, std::vector<gnn::FeatureType> feature_types) { [](gnn::GraphData &g, std::shared_ptr<Tensor> node_list, std::vector<gnn::FeatureType> feature_types) {
TensorRow out; TensorRow out;
THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out)); THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out));
return out.getRow(); return out.getRow();
}) })
.def("get_edge_feature", .def("get_edge_feature",
[](gnn::Graph &g, std::shared_ptr<Tensor> edge_list, std::vector<gnn::FeatureType> feature_types) { [](gnn::GraphData &g, std::shared_ptr<Tensor> edge_list, std::vector<gnn::FeatureType> feature_types) {
TensorRow out; TensorRow out;
THROW_IF_ERROR(g.GetEdgeFeature(edge_list, feature_types, &out)); THROW_IF_ERROR(g.GetEdgeFeature(edge_list, feature_types, &out));
return out.getRow(); return out.getRow();
}) })
.def("graph_info", .def("graph_info",
[](gnn::Graph &g) { [](gnn::GraphData &g) {
py::dict out; py::dict out;
THROW_IF_ERROR(g.GraphInfo(&out)); THROW_IF_ERROR(g.GraphInfo(&out));
return out; return out;
}) })
.def("random_walk", .def("random_walk",
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeType> meta_path, [](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeType> meta_path,
float step_home_param, float step_away_param, gnn::NodeIdType default_node) { float step_home_param, float step_away_param, gnn::NodeIdType default_node) {
std::shared_ptr<Tensor> out; std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.RandomWalk(node_list, meta_path, step_home_param, step_away_param, default_node, &out)); THROW_IF_ERROR(g.RandomWalk(node_list, meta_path, step_home_param, step_away_param, default_node, &out));
return out; return out;
}); })
.def("stop", [](gnn::GraphData &g) { THROW_IF_ERROR(g.Stop()); });
(void)py::class_<gnn::GraphDataServer, std::shared_ptr<gnn::GraphDataServer>>(*m, "GraphDataServer")
.def(py::init([](const std::string &dataset_file, int32_t num_workers, const std::string &hostname, int32_t port,
int32_t client_num, bool auto_shutdown) {
std::shared_ptr<gnn::GraphDataServer> out;
out =
std::make_shared<gnn::GraphDataServer>(dataset_file, num_workers, hostname, port, client_num, auto_shutdown);
THROW_IF_ERROR(out->Init());
return out;
}))
.def("stop", [](gnn::GraphDataServer &g) { THROW_IF_ERROR(g.Stop()); })
.def("is_stoped", [](gnn::GraphDataServer &g) { return g.IsStoped(); });
})); }));
} // namespace dataset } // namespace dataset

View File

@ -1,9 +1,29 @@
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_library(engine-gnn OBJECT set(DATASET_ENGINE_GNN_SRC_FILES
graph.cc graph_data_impl.cc
graph_data_client.cc
graph_data_server.cc
graph_loader.cc graph_loader.cc
graph_feature_parser.cc
local_node.cc local_node.cc
local_edge.cc local_edge.cc
feature.cc feature.cc
) )
# Non-Windows builds extend the base source list with the gRPC/shared-memory sources and the
# generated protobuf/gRPC stubs; Windows builds compile the base source list only.
if (NOT CMAKE_SYSTEM_NAME MATCHES "Windows")
    list(APPEND DATASET_ENGINE_GNN_SRC_FILES
        tensor_proto.cc
        grpc_async_server.cc
        graph_data_service_impl.cc
        graph_shared_memory.cc)
    ms_protobuf_generate(TENSOR_PROTO_SRCS TENSOR_PROTO_HDRS "gnn_tensor.proto")
    ms_grpc_generate(GNN_PROTO_SRCS GNN_PROTO_HDRS "gnn_graph_data.proto")
    add_library(engine-gnn OBJECT
        ${DATASET_ENGINE_GNN_SRC_FILES}
        ${TENSOR_PROTO_SRCS}
        ${GNN_PROTO_SRCS})
    add_dependencies(engine-gnn mindspore::protobuf)
else()
    add_library(engine-gnn OBJECT ${DATASET_ENGINE_GNN_SRC_FILES})
endif()

View File

@ -19,7 +19,8 @@ namespace mindspore {
namespace dataset { namespace dataset {
namespace gnn { namespace gnn {
Feature::Feature(FeatureType type_name, std::shared_ptr<Tensor> value) : type_name_(type_name), value_(value) {} Feature::Feature(FeatureType type_name, std::shared_ptr<Tensor> value, bool is_shared_memory)
: type_name_(type_name), value_(value), is_shared_memory_(is_shared_memory) {}
} // namespace gnn } // namespace gnn
} // namespace dataset } // namespace dataset

View File

@ -31,7 +31,7 @@ class Feature {
// Constructor // Constructor
// @param FeatureType type_name - feature type // @param FeatureType type_name - feature type
// @param std::shared_ptr<Tensor> value - feature value // @param std::shared_ptr<Tensor> value - feature value
Feature(FeatureType type_name, std::shared_ptr<Tensor> value); Feature(FeatureType type_name, std::shared_ptr<Tensor> value, bool is_shared_memory = false);
~Feature() = default; ~Feature() = default;
@ -45,6 +45,7 @@ class Feature {
private: private:
FeatureType type_name_; FeatureType type_name_;
std::shared_ptr<Tensor> value_; std::shared_ptr<Tensor> value_;
bool is_shared_memory_;
}; };
} // namespace gnn } // namespace gnn
} // namespace dataset } // namespace dataset

View File

@ -0,0 +1,103 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
syntax = "proto3";
package mindspore.dataset;
import "gnn_tensor.proto";
// Request sent by a graph data client to register itself with the server.
message GnnClientRegisterRequestPb {
// Identifier of the registering client process.
// NOTE(review): presumably the value of getpid() on the client side - confirm in RegisterToServer().
int32 pid = 1;
}
// A feature value paired with its feature type.
message GnnFeatureInfoPb {
// Feature type id.
int32 type = 1;
// Feature value serialized as a tensor.
TensorPb feature = 2;
}
// Server reply to a client registration request.
message GnnClientRegisterResponsePb {
// Result text of the registration.
// NOTE(review): other responses in this service use "Success" for the OK case - confirm this one does too.
string error_msg = 1;
// Schema description of the graph data (format is defined by the server implementation).
string data_schema = 2;
// Key and size of the shared memory segment offered to the client.
int64 shared_memory_key = 3;
int64 shared_memory_size = 4;
// Default feature values, one entry per feature type, for nodes and for edges.
repeated GnnFeatureInfoPb default_node_feature = 5;
repeated GnnFeatureInfoPb default_edge_feature = 6;
}
// Request sent by a graph data client to remove its registration from the server.
message GnnClientUnRegisterRequestPb {
// Identifier of the client process that is leaving.
int32 pid = 1;
}
// Server reply to an unregistration request.
message GnnClientUnRegisterResponsePb {
// Result text of the unregistration.
string error_msg = 1;
}
// Operation selector carried in GnnGraphDataRequestPb.op_name.
// Each value corresponds to the GraphData method of the same name.
enum GnnOpName {
GET_ALL_NODES = 0;
GET_ALL_EDGES = 1;
GET_NODES_FROM_EDGES = 2;
GET_ALL_NEIGHBORS = 3;
GET_SAMPLED_NEIGHBORS = 4;
GET_NEG_SAMPLED_NEIGHBORS = 5;
RANDOM_WALK = 6;
GET_NODE_FEATURE = 7;
GET_EDGE_FEATURE = 8;
}
// Extra parameters for the RANDOM_WALK operation.
message GnnRandomWalkPb {
// Filled by the client from step_home_param.
float p = 1;
// Filled by the client from step_away_param.
float q = 2;
// Filled by the client from default_node.
int32 default_id = 3;
}
// Generic request for all graph data operations; which fields are filled depends on op_name.
message GnnGraphDataRequestPb {
// Operation to execute.
GnnOpName op_name = 1;
repeated int32 id = 2; // node id or edge id
repeated int32 type = 3; // node type or edge type or neighbor type or feature type
repeated int32 number = 4; // samples number
// Used by GET_NODE_FEATURE / GET_EDGE_FEATURE, which pass ids as a tensor instead of the id list.
TensorPb id_tensor = 5; // input ids, node id or edge id
// Only filled for RANDOM_WALK.
GnnRandomWalkPb random_walk = 6;
}
// Generic response for all graph data operations.
message GnnGraphDataResponsePb {
// "Success" when the operation succeeded; the client reports any other text as an error.
string error_msg = 1;
// Result tensors: the client expects exactly one for the id-returning operations and
// one per requested feature type for the feature operations.
repeated TensorPb result_data = 2;
}
// Request for the graph meta information; carries no parameters.
message GnnMetaInfoRequestPb {
}
// Number of nodes (or edges) of a given type.
message GnnNodeEdgeInfoPb {
// Node type or edge type.
int32 type = 1;
// Count for that type.
int32 num = 2;
}
// Graph meta information returned to the client.
message GnnMetaInfoResponsePb {
// "Success" when the request succeeded; the client reports any other text as an error.
string error_msg = 1;
// Per-type node counts and per-type edge counts.
repeated GnnNodeEdgeInfoPb node_info = 2;
repeated GnnNodeEdgeInfoPb edge_info = 3;
// Feature types available on nodes and on edges.
repeated int32 node_feature_type = 4;
repeated int32 edge_feature_type = 5;
}
// gRPC service through which graph data clients talk to the graph data server.
service GnnGraphData {
// Register / unregister a client process.
rpc ClientRegister(GnnClientRegisterRequestPb) returns (GnnClientRegisterResponsePb);
rpc ClientUnRegister(GnnClientUnRegisterRequestPb) returns (GnnClientUnRegisterResponsePb);
// Execute one graph data operation selected by GnnGraphDataRequestPb.op_name.
rpc GetGraphData(GnnGraphDataRequestPb) returns (GnnGraphDataResponsePb);
// Query node/edge types, counts and feature types.
rpc GetMetaInfo(GnnMetaInfoRequestPb) returns (GnnMetaInfoResponsePb);
}

View File

@ -0,0 +1,42 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
syntax = "proto3";
package mindspore.dataset;
// Element type of the data carried by a TensorPb.
// NOTE(review): the names presumably mirror the dataset engine's DataType values - confirm in tensor_proto.cc.
enum DataTypePb {
DE_PB_UNKNOWN = 0;
DE_PB_BOOL = 1;
DE_PB_INT8 = 2;
DE_PB_UINT8 = 3;
DE_PB_INT16 = 4;
DE_PB_UINT16 = 5;
DE_PB_INT32 = 6;
DE_PB_UINT32 = 7;
DE_PB_INT64 = 8;
DE_PB_UINT64 = 9;
DE_PB_FLOAT16 = 10;
DE_PB_FLOAT32 = 11;
DE_PB_FLOAT64 = 12;
DE_PB_STRING = 13;
}
// Wire representation of a tensor: shape, element type and raw content.
message TensorPb {
repeated int64 dims = 1; // tensor shape info
DataTypePb tensor_type = 2; // tensor content data type
bytes data = 3; // tensor data
}

View File

@ -0,0 +1,134 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
#include <utility>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/core/tensor_row.h"
#include "minddata/dataset/engine/gnn/feature.h"
#include "minddata/dataset/engine/gnn/node.h"
#include "minddata/dataset/engine/gnn/edge.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace gnn {
// Summary of a graph: which node/edge/feature types exist and how many nodes/edges each type has.
struct MetaInfo {
// Node types and edge types present in the graph.
std::vector<NodeType> node_type;
std::vector<EdgeType> edge_type;
// Number of nodes per node type and number of edges per edge type.
std::map<NodeType, NodeIdType> node_num;
std::map<EdgeType, EdgeIdType> edge_num;
// Feature types available on nodes and on edges.
std::vector<FeatureType> node_feature_type;
std::vector<FeatureType> edge_feature_type;
};
// Abstract interface for graph data access. Concrete implementations provide the
// operations either on locally loaded data or by forwarding them to a remote server.
class GraphData {
 public:
  // Virtual destructor so that an implementation can be destroyed safely through a
  // pointer or reference to this base class.
  virtual ~GraphData() = default;

  // Get all nodes from the graph.
  // @param NodeType node_type - type of node
  // @param std::shared_ptr<Tensor> *out - Returned nodes id
  // @return Status - The error code return
  virtual Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) = 0;

  // Get all edges from the graph.
  // @param EdgeType edge_type - type of edge
  // @param std::shared_ptr<Tensor> *out - Returned edge ids
  // @return Status - The error code return
  virtual Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) = 0;

  // Get the node id from the edge.
  // @param std::vector<EdgeIdType> edge_list - List of edges
  // @param std::shared_ptr<Tensor> *out - Returned node ids
  // @return Status - The error code return
  virtual Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) = 0;

  // Get all neighbors of the given nodes.
  // @param std::vector<NodeIdType> node_list - List of nodes
  // @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported
  // @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is
  // different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors
  // is not enough, fill in tensor as -1.
  // @return Status - The error code return
  virtual Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
                                 std::shared_ptr<Tensor> *out) = 0;

  // Get sampled neighbors.
  // @param std::vector<NodeIdType> node_list - List of nodes
  // @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
  // @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
  // @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
  // @return Status - The error code return
  virtual Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
                                     const std::vector<NodeIdType> &neighbor_nums,
                                     const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) = 0;

  // Get negative sampled neighbors.
  // @param std::vector<NodeIdType> node_list - List of nodes
  // @param NodeIdType samples_num - Number of neighbors sampled
  // @param NodeType neg_neighbor_type - The type of negative neighbor.
  // @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id.
  // @return Status - The error code return
  virtual Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
                                        NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) = 0;

  // Node2vec random walk.
  // @param std::vector<NodeIdType> node_list - List of nodes
  // @param std::vector<NodeType> meta_path - node type of each step
  // @param float step_home_param - return hyper parameter in node2vec algorithm
  // @param float step_away_param - in-out hyper parameter in node2vec algorithm
  // @param NodeIdType default_node - default node id
  // @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path
  // @return Status - The error code return
  virtual Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
                            float step_home_param, float step_away_param, NodeIdType default_node,
                            std::shared_ptr<Tensor> *out) = 0;

  // Get the feature of a node
  // @param std::shared_ptr<Tensor> nodes - List of nodes
  // @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
  // does not exist.
  // @param TensorRow *out - Returned features
  // @return Status - The error code return
  virtual Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
                                TensorRow *out) = 0;

  // Get the feature of an edge
  // @param std::shared_ptr<Tensor> edges - List of edges
  // @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
  // does not exist.
  // @param TensorRow *out - Returned features
  // @return Status - The error code return
  virtual Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types,
                                TensorRow *out) = 0;

  // Return meta information to python layer
  // @param py::dict *out - Dictionary that receives the meta information
  // @return Status - The error code return
  virtual Status GraphInfo(py::dict *out) = 0;

  // Prepare the object for use; must succeed before any query is issued.
  // @return Status - The error code return
  virtual Status Init() = 0;

  // Release whatever Init() acquired.
  // @return Status - The error code return
  virtual Status Stop() = 0;
};
} // namespace gnn
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_H_

View File

@ -0,0 +1,589 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/gnn/graph_data_client.h"
#include <unistd.h>
#include <functional>
#include <map>
#if !defined(_WIN32) && !defined(_WIN64)
#include "grpcpp/grpcpp.h"
#endif
#include "minddata/dataset/core/data_type.h"
#if !defined(_WIN32) && !defined(_WIN64)
#include "minddata/dataset/engine/gnn/tensor_proto.h"
#endif
namespace mindspore {
namespace dataset {
namespace gnn {
// Constructor. Only records the connection parameters; the connection itself and the
// registration with the server are performed by Init().
// @param std::string dataset_file - stored in dataset_file_
// @param std::string hostname - host part of the server address ("hostname:port")
// @param int32_t port - port part of the server address
GraphDataClient::GraphDataClient(const std::string &dataset_file, const std::string &hostname, int32_t port)
: dataset_file_(dataset_file),
host_(hostname),
port_(port),
pid_(0),
#if !defined(_WIN32) && !defined(_WIN64)
// Shared memory and feature parser members are only initialized on non-Windows builds.
shared_memory_key_(-1),
shared_memory_size_(0),
graph_feature_parser_(nullptr),
graph_shared_memory_(nullptr),
#endif
registered_(false) {
}
// Destructor. Calls Stop() and deliberately discards its Status.
GraphDataClient::~GraphDataClient() { (void)Stop(); }
// Connect to the graph data server, register this process with it and set up the feature parser.
// On Windows this always fails because the client is not supported there.
// @return Status - The error code return
Status GraphDataClient::Init() {
#if defined(_WIN32) || defined(_WIN64)
  RETURN_STATUS_UNEXPECTED("Graph data client is not supported in Windows OS");
#else
  if (!registered_) {
    const std::string server_address = host_ + ":" + std::to_string(port_);
    MS_LOG(INFO) << "Graph data client starting. address:" << server_address;
    pid_ = getpid();

    // -1 removes the limit on the size of received messages.
    grpc::ChannelArguments channel_args;
    channel_args.SetMaxReceiveMessageSize(-1);
    stub_ = GnnGraphData::NewStub(
      grpc::CreateCustomChannel(server_address, grpc::InsecureChannelCredentials(), channel_args));

    // Keep retrying every two seconds for as long as the returned status text contains "Initializing".
    // NOTE(review): there is no upper bound on the number of retries.
    Status reg_status = RegisterToServer();
    for (; reg_status.ToString().find("Initializing") != std::string::npos; reg_status = RegisterToServer()) {
      MS_LOG(INFO) << "Graph data server is initializing, please wait.";
      std::this_thread::sleep_for(std::chrono::milliseconds(2000));
    }
    RETURN_IF_NOT_OK(reg_status);
    MS_LOG(INFO) << "Graph data client successfully registered with server " << server_address;
  }
  RETURN_IF_NOT_OK(InitFeatureParser());
  return Status::OK();
#endif
}
// Unregister from the server if this client is currently registered.
// The outcome of the unregistration is not propagated; the method always reports OK.
// @return Status - Always Status::OK()
Status GraphDataClient::Stop() {
#if !defined(_WIN32) && !defined(_WIN64)
  if (!registered_) {
    return Status::OK();
  }
  (void)UnRegisterToServer();
#endif
  return Status::OK();
}
// Ask the server for the ids of all nodes of one type.
// @param NodeType node_type - type of node
// @param std::shared_ptr<Tensor> *out - Returned nodes id
// @return Status - The error code return
Status GraphDataClient::GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) {
#if !defined(_WIN32) && !defined(_WIN64)
  GnnGraphDataResponsePb rsp;
  GnnGraphDataRequestPb req;
  req.add_type(static_cast<google::protobuf::int32>(node_type));
  req.set_op_name(GET_ALL_NODES);
  RETURN_IF_NOT_OK(GetGraphDataTensor(req, &rsp, out));
#endif
  // On Windows the RPC is compiled out and *out is left untouched.
  return Status::OK();
}
// Ask the server for the ids of all edges of one type.
// @param EdgeType edge_type - type of edge
// @param std::shared_ptr<Tensor> *out - Returned edge ids
// @return Status - The error code return
Status GraphDataClient::GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) {
#if !defined(_WIN32) && !defined(_WIN64)
  GnnGraphDataResponsePb rsp;
  GnnGraphDataRequestPb req;
  req.add_type(static_cast<google::protobuf::int32>(edge_type));
  req.set_op_name(GET_ALL_EDGES);
  RETURN_IF_NOT_OK(GetGraphDataTensor(req, &rsp, out));
#endif
  // On Windows the RPC is compiled out and *out is left untouched.
  return Status::OK();
}
// Ask the server for the node ids that belong to the given edges.
// @param std::vector<EdgeIdType> edge_list - List of edges
// @param std::shared_ptr<Tensor> *out - Returned node ids
// @return Status - The error code return
Status GraphDataClient::GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) {
#if !defined(_WIN32) && !defined(_WIN64)
  GnnGraphDataResponsePb rsp;
  GnnGraphDataRequestPb req;
  req.set_op_name(GET_NODES_FROM_EDGES);
  for (const auto &eid : edge_list) {
    req.add_id(static_cast<google::protobuf::int32>(eid));
  }
  RETURN_IF_NOT_OK(GetGraphDataTensor(req, &rsp, out));
#endif
  // On Windows the RPC is compiled out and *out is left untouched.
  return Status::OK();
}
// Ask the server for all neighbors of the given nodes that have the requested type.
// @param std::vector<NodeIdType> node_list - List of nodes
// @param NodeType neighbor_type - The type of neighbor
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id
// @return Status - The error code return
Status GraphDataClient::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
                                        std::shared_ptr<Tensor> *out) {
#if !defined(_WIN32) && !defined(_WIN64)
  GnnGraphDataResponsePb rsp;
  GnnGraphDataRequestPb req;
  req.set_op_name(GET_ALL_NEIGHBORS);
  req.add_type(static_cast<google::protobuf::int32>(neighbor_type));
  for (const auto &nid : node_list) {
    req.add_id(static_cast<google::protobuf::int32>(nid));
  }
  RETURN_IF_NOT_OK(GetGraphDataTensor(req, &rsp, out));
#endif
  // On Windows the RPC is compiled out and *out is left untouched.
  return Status::OK();
}
// Ask the server for sampled neighbors of the given nodes.
// @param std::vector<NodeIdType> node_list - List of nodes
// @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
// @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id
// @return Status - The error code return
Status GraphDataClient::GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
                                            const std::vector<NodeIdType> &neighbor_nums,
                                            const std::vector<NodeType> &neighbor_types,
                                            std::shared_ptr<Tensor> *out) {
#if !defined(_WIN32) && !defined(_WIN64)
  GnnGraphDataResponsePb rsp;
  GnnGraphDataRequestPb req;
  req.set_op_name(GET_SAMPLED_NEIGHBORS);
  for (const auto &nid : node_list) {
    req.add_id(static_cast<google::protobuf::int32>(nid));
  }
  for (const auto &count : neighbor_nums) {
    req.add_number(static_cast<google::protobuf::int32>(count));
  }
  for (const auto &ntype : neighbor_types) {
    req.add_type(static_cast<google::protobuf::int32>(ntype));
  }
  RETURN_IF_NOT_OK(GetGraphDataTensor(req, &rsp, out));
#endif
  // On Windows the RPC is compiled out and *out is left untouched.
  return Status::OK();
}
// Ask the server for negative sampled neighbors of the given nodes.
// @param std::vector<NodeIdType> node_list - List of nodes
// @param NodeIdType samples_num - Number of neighbors sampled
// @param NodeType neg_neighbor_type - The type of negative neighbor
// @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id
// @return Status - The error code return
Status GraphDataClient::GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
                                               NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) {
#if !defined(_WIN32) && !defined(_WIN64)
  GnnGraphDataResponsePb rsp;
  GnnGraphDataRequestPb req;
  req.set_op_name(GET_NEG_SAMPLED_NEIGHBORS);
  req.add_number(static_cast<google::protobuf::int32>(samples_num));
  req.add_type(static_cast<google::protobuf::int32>(neg_neighbor_type));
  for (const auto &nid : node_list) {
    req.add_id(static_cast<google::protobuf::int32>(nid));
  }
  RETURN_IF_NOT_OK(GetGraphDataTensor(req, &rsp, out));
#endif
  // On Windows the RPC is compiled out and *out is left untouched.
  return Status::OK();
}
// Ask the server to perform a node2vec random walk starting from the given nodes.
// @param std::vector<NodeIdType> node_list - List of nodes
// @param std::vector<NodeType> meta_path - node type of each step
// @param float step_home_param - sent to the server as GnnRandomWalkPb.p
// @param float step_away_param - sent to the server as GnnRandomWalkPb.q
// @param NodeIdType default_node - default node id
// @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path
// @return Status - The error code return
Status GraphDataClient::RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
                                   float step_home_param, float step_away_param, NodeIdType default_node,
                                   std::shared_ptr<Tensor> *out) {
#if !defined(_WIN32) && !defined(_WIN64)
  GnnGraphDataRequestPb request;
  GnnGraphDataResponsePb response;
  request.set_op_name(RANDOM_WALK);
  for (const auto &node_id : node_list) {
    request.add_id(static_cast<google::protobuf::int32>(node_id));
  }
  for (const auto &type : meta_path) {
    request.add_type(static_cast<google::protobuf::int32>(type));
  }
  auto walk_param = request.mutable_random_walk();
  walk_param->set_p(step_home_param);
  walk_param->set_q(step_away_param);
  walk_param->set_default_id(static_cast<google::protobuf::int32>(default_node));
  RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out));
#endif
  // On Windows the RPC is compiled out and *out is left untouched.
  return Status::OK();
}
// Fetch the requested features of the given nodes from the server.
// The server returns one tensor per requested feature type; each of them is turned into
// the final feature tensor by ParseNodeFeatureFromMemory() and appended to *out in request order.
// @param std::shared_ptr<Tensor> nodes - List of nodes, must be non-null and non-empty
// @param std::vector<FeatureType> feature_types - Types of features, must be non-empty
// @param TensorRow *out - Returned features
// @return Status - The error code return
Status GraphDataClient::GetNodeFeature(const std::shared_ptr<Tensor> &nodes,
                                       const std::vector<FeatureType> &feature_types, TensorRow *out) {
#if !defined(_WIN32) && !defined(_WIN64)
  if (!nodes || nodes->Size() == 0) {
    RETURN_STATUS_UNEXPECTED("Input nodes is empty");
  }
  CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Input feature_types is empty");

  GnnGraphDataRequestPb request;
  GnnGraphDataResponsePb response;
  request.set_op_name(GET_NODE_FEATURE);
  for (const auto &type : feature_types) {
    request.add_type(static_cast<google::protobuf::int32>(type));
  }
  RETURN_IF_NOT_OK(TensorToPb(nodes, request.mutable_id_tensor()));
  RETURN_IF_NOT_OK(GetGraphData(request, &response));

  // feature_types is known to be non-empty here, so an empty result is rejected by this check as well.
  CHECK_FAIL_RETURN_UNEXPECTED(feature_types.size() == static_cast<size_t>(response.result_data().size()),
                               "The number of feature types returned by the server is wrong");
  for (size_t i = 0; i < feature_types.size(); ++i) {
    const TensorPb &result = response.result_data(static_cast<int>(i));
    std::shared_ptr<Tensor> tensor;
    RETURN_IF_NOT_OK(PbToTensor(&result, &tensor));
    std::shared_ptr<Tensor> fea_tensor;
    RETURN_IF_NOT_OK(ParseNodeFeatureFromMemory(nodes, feature_types[i], tensor, &fea_tensor));
    out->emplace_back(std::move(fea_tensor));
  }
#endif
  // On Windows the RPC is compiled out and *out is left untouched.
  return Status::OK();
}
// Fetch the requested features of the given edges from the server.
// The server returns one tensor per requested feature type; each of them is turned into
// the final feature tensor by ParseEdgeFeatureFromMemory() and appended to *out in request order.
// @param std::shared_ptr<Tensor> edges - List of edges, must be non-null and non-empty
// @param std::vector<FeatureType> feature_types - Types of features, must be non-empty
// @param TensorRow *out - Returned features
// @return Status - The error code return
Status GraphDataClient::GetEdgeFeature(const std::shared_ptr<Tensor> &edges,
                                       const std::vector<FeatureType> &feature_types, TensorRow *out) {
#if !defined(_WIN32) && !defined(_WIN64)
  if (!edges || edges->Size() == 0) {
    RETURN_STATUS_UNEXPECTED("Input edges is empty");
  }
  CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Input feature_types is empty");

  GnnGraphDataRequestPb request;
  GnnGraphDataResponsePb response;
  request.set_op_name(GET_EDGE_FEATURE);
  for (const auto &type : feature_types) {
    request.add_type(static_cast<google::protobuf::int32>(type));
  }
  RETURN_IF_NOT_OK(TensorToPb(edges, request.mutable_id_tensor()));
  RETURN_IF_NOT_OK(GetGraphData(request, &response));

  // feature_types is known to be non-empty here, so an empty result is rejected by this check as well.
  CHECK_FAIL_RETURN_UNEXPECTED(feature_types.size() == static_cast<size_t>(response.result_data().size()),
                               "The number of feature types returned by the server is wrong");
  for (size_t i = 0; i < feature_types.size(); ++i) {
    const TensorPb &result = response.result_data(static_cast<int>(i));
    std::shared_ptr<Tensor> tensor;
    RETURN_IF_NOT_OK(PbToTensor(&result, &tensor));
    std::shared_ptr<Tensor> fea_tensor;
    RETURN_IF_NOT_OK(ParseEdgeFeatureFromMemory(edges, feature_types[i], tensor, &fea_tensor));
    out->emplace_back(std::move(fea_tensor));
  }
#endif
  // On Windows the RPC is compiled out and *out is left untouched.
  return Status::OK();
}
// Query the graph meta information from the server and expose it as a python dict with the keys
// "node_type", "edge_type", "node_num", "edge_num", "node_feature_type" and "edge_feature_type".
// @param py::dict *out - Dictionary that receives the meta information
// @return Status - The error code return
Status GraphDataClient::GraphInfo(py::dict *out) {
#if !defined(_WIN32) && !defined(_WIN64)
  RETURN_IF_NOT_OK(CheckPid());
  void *tag;
  bool ok;
  grpc::Status rpc_status;
  grpc::ClientContext ctx;
  grpc::CompletionQueue cq;
  GnnMetaInfoRequestPb request;
  GnnMetaInfoResponsePb response;
  // One minute timeout
  ctx.set_deadline(std::chrono::system_clock::now() + std::chrono::seconds(60));
  std::unique_ptr<grpc::ClientAsyncResponseReader<GnnMetaInfoResponsePb>> rpc(
    stub_->PrepareAsyncGetMetaInfo(&ctx, request, &cq));
  rpc->StartCall();
  // The address of the response object doubles as the completion tag.
  rpc->Finish(&response, &rpc_status, &response);
  {
    // Release the GIL while waiting for the completion event.
    py::gil_scoped_release gil_release;
    auto success = cq.Next(&tag, &ok);
    CHECK_FAIL_RETURN_UNEXPECTED(success, "Expect successful");
    CHECK_FAIL_RETURN_UNEXPECTED(tag == &response, "Expect the same tag");
    CHECK_FAIL_RETURN_UNEXPECTED(ok, "Expect successful");
  }

  // Transport level failure.
  if (!rpc_status.ok()) {
    auto error_code = rpc_status.error_code();
    RETURN_STATUS_UNEXPECTED(rpc_status.error_message() + ". GRPC Code " + std::to_string(error_code));
  }
  // Application level failure reported by the server.
  if (response.error_msg() != "Success") {
    RETURN_STATUS_UNEXPECTED(response.error_msg());
  }

  MetaInfo meta_info;
  for (const auto &node : response.node_info()) {
    const auto node_type = static_cast<NodeType>(node.type());
    meta_info.node_type.emplace_back(node_type);
    meta_info.node_num[node_type] = static_cast<NodeIdType>(node.num());
  }
  for (const auto &edge : response.edge_info()) {
    const auto edge_type = static_cast<EdgeType>(edge.type());
    meta_info.edge_type.emplace_back(edge_type);
    meta_info.edge_num[edge_type] = static_cast<EdgeIdType>(edge.num());
  }
  for (const auto &feature_type : response.node_feature_type()) {
    meta_info.node_feature_type.emplace_back(static_cast<FeatureType>(feature_type));
  }
  for (const auto &feature_type : response.edge_feature_type()) {
    meta_info.edge_feature_type.emplace_back(static_cast<FeatureType>(feature_type));
  }

  py::dict &info = *out;
  info["node_type"] = py::cast(meta_info.node_type);
  info["edge_type"] = py::cast(meta_info.edge_type);
  info["node_num"] = py::cast(meta_info.node_num);
  info["edge_num"] = py::cast(meta_info.edge_num);
  info["node_feature_type"] = py::cast(meta_info.node_feature_type);
  info["edge_feature_type"] = py::cast(meta_info.edge_feature_type);
#endif
  // On Windows the RPC is compiled out and *out is left untouched.
  return Status::OK();
}
#if !defined(_WIN32) && !defined(_WIN64)
// Send one GetGraphData request to the server and block until its reply arrives.
// @param const GnnGraphDataRequestPb &request - Request message forwarded to the server
// @param GnnGraphDataResponsePb *response - Receives the reply; its address is also used as the completion tag
// @return Status - The error code return
Status GraphDataClient::GetGraphData(const GnnGraphDataRequestPb &request, GnnGraphDataResponsePb *response) {
  RETURN_IF_NOT_OK(CheckPid());
  void *finished_tag = nullptr;
  bool finished_ok = false;
  // Declaration order of the gRPC objects is kept on purpose: it fixes their destruction order.
  grpc::Status rpc_status;
  grpc::ClientContext client_ctx;
  grpc::CompletionQueue queue;
  // One minute timeout
  client_ctx.set_deadline(std::chrono::system_clock::now() + std::chrono::seconds(60));
  auto call = stub_->PrepareAsyncGetGraphData(&client_ctx, request, &queue);
  call->StartCall();
  call->Finish(response, &rpc_status, response);
  {
    // The GIL is released only for the duration of the blocking wait on the completion queue.
    py::gil_scoped_release gil_release;
    const bool got_event = queue.Next(&finished_tag, &finished_ok);
    CHECK_FAIL_RETURN_UNEXPECTED(got_event, "Expect successful");
    CHECK_FAIL_RETURN_UNEXPECTED(finished_tag == response, "Expect the same tag");
    CHECK_FAIL_RETURN_UNEXPECTED(finished_ok, "Expect successful");
  }
  // Transport-level failure reported by gRPC.
  if (!rpc_status.ok()) {
    auto error_code = rpc_status.error_code();
    RETURN_STATUS_UNEXPECTED(rpc_status.error_message() + ". GRPC Code " + std::to_string(error_code));
  }
  // Application-level failure reported by the server inside the reply.
  if (response->error_msg() != "Success") {
    RETURN_STATUS_UNEXPECTED(response->error_msg());
  }
  return Status::OK();
}
// Run GetGraphData and convert the single tensor carried by the reply.
// @param const GnnGraphDataRequestPb &request - Request message forwarded to the server
// @param GnnGraphDataResponsePb *response - Receives the raw reply
// @param std::shared_ptr<Tensor> *out - Set to the converted tensor; left untouched on failure
// @return Status - The error code return
Status GraphDataClient::GetGraphDataTensor(const GnnGraphDataRequestPb &request, GnnGraphDataResponsePb *response,
                                           std::shared_ptr<Tensor> *out) {
  RETURN_IF_NOT_OK(GetGraphData(request, response));
  // Exactly one tensor is expected in the reply.
  if (1 != response->result_data().size()) {
    RETURN_STATUS_UNEXPECTED("RPC failed: The number of returned tensor is abnormal");
  }
  const TensorPb &tensor_pb = response->result_data()[0];
  std::shared_ptr<Tensor> converted;
  RETURN_IF_NOT_OK(PbToTensor(&tensor_pb, &converted));
  *out = std::move(converted);
  return Status::OK();
}
// Build the feature tensor of the given nodes from (offset, len) pairs that point into shared memory.
// @param std::shared_ptr<Tensor> nodes - Node ids; one feature row is produced per id
// @param FeatureType feature_type - Feature type, used to look up the default feature
// @param std::shared_ptr<Tensor> memory_tensor - int64 tensor read as consecutive (offset, len) pairs, one per node
// @param std::shared_ptr<Tensor> *out - Returned features, shaped as nodes' shape followed by the feature shape
// @return Status - The error code return
Status GraphDataClient::ParseNodeFeatureFromMemory(const std::shared_ptr<Tensor> &nodes, FeatureType feature_type,
                                                   const std::shared_ptr<Tensor> &memory_tensor,
                                                   std::shared_ptr<Tensor> *out) {
  CHECK_FAIL_RETURN_UNEXPECTED(nodes != nullptr && memory_tensor != nullptr, "Input tensor is null");
  std::shared_ptr<Tensor> default_feature;
  // If no feature can be obtained, fill in the default value
  RETURN_IF_NOT_OK(GetNodeDefaultFeature(feature_type, &default_feature));
  TensorShape shape(default_feature->shape());
  auto shape_vec = nodes->shape().AsVector();
  // The initial value must be dsize_t: std::accumulate keeps its running result in the type of the
  // initial value, so an int literal would truncate the product to int at every step.
  dsize_t size =
    std::accumulate(shape_vec.begin(), shape_vec.end(), static_cast<dsize_t>(1), std::multiplies<dsize_t>());
  // Two int64 entries (offset, len) are consumed per node below; reject a tensor that is too short
  // instead of reading past its end.
  CHECK_FAIL_RETURN_UNEXPECTED(memory_tensor->Size() >= 2 * size,
                               "The shared memory address tensor is smaller than expected");
  shape = shape.PrependDim(size);
  std::shared_ptr<Tensor> fea_tensor;
  RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, default_feature->type(), &fea_tensor));
  dsize_t index = 0;
  auto fea_addr_itr = memory_tensor->begin<int64_t>();
  for (auto node_itr = nodes->begin<NodeIdType>(); node_itr != nodes->end<NodeIdType>(); ++node_itr) {
    int64_t offset = *fea_addr_itr;
    fea_addr_itr++;
    int64_t len = *fea_addr_itr;
    fea_addr_itr++;
    if (*node_itr == kDefaultNodeId || offset < 0 || len <= 0) {
      // Placeholder node or no valid address: use the default feature for this row.
      RETURN_IF_NOT_OK(fea_tensor->InsertTensor({index}, default_feature));
    } else {
      // Copy len bytes from shared memory at offset straight into row `index` of the output.
      uchar *start_addr_of_index = nullptr;
      TensorShape remaining({-1});
      RETURN_IF_NOT_OK(fea_tensor->StartAddrOfIndex({index}, &start_addr_of_index, &remaining));
      RETURN_IF_NOT_OK(graph_shared_memory_->GetData(start_addr_of_index, len, offset, len));
    }
    index++;
  }
  // Restore the shape of `nodes` in front of the per-feature dimensions.
  TensorShape reshape(nodes->shape());
  for (auto s : default_feature->shape().AsVector()) {
    reshape = reshape.AppendDim(s);
  }
  RETURN_IF_NOT_OK(fea_tensor->Reshape(reshape));
  fea_tensor->Squeeze();
  *out = std::move(fea_tensor);
  return Status::OK();
}
// Build the feature tensor of the given edges from (offset, len) pairs that point into shared memory.
// @param std::shared_ptr<Tensor> edges - Edge ids; one feature row is produced per id
// @param FeatureType feature_type - Feature type, used to look up the default feature
// @param std::shared_ptr<Tensor> memory_tensor - int64 tensor read as consecutive (offset, len) pairs, one per edge
// @param std::shared_ptr<Tensor> *out - Returned features, shaped as edges' shape followed by the feature shape
// @return Status - The error code return
Status GraphDataClient::ParseEdgeFeatureFromMemory(const std::shared_ptr<Tensor> &edges, FeatureType feature_type,
                                                   const std::shared_ptr<Tensor> &memory_tensor,
                                                   std::shared_ptr<Tensor> *out) {
  CHECK_FAIL_RETURN_UNEXPECTED(edges != nullptr && memory_tensor != nullptr, "Input tensor is null");
  std::shared_ptr<Tensor> default_feature;
  // If no feature can be obtained, fill in the default value
  RETURN_IF_NOT_OK(GetEdgeDefaultFeature(feature_type, &default_feature));
  TensorShape shape(default_feature->shape());
  auto shape_vec = edges->shape().AsVector();
  // The initial value must be dsize_t: std::accumulate keeps its running result in the type of the
  // initial value, so an int literal would truncate the product to int at every step.
  dsize_t size =
    std::accumulate(shape_vec.begin(), shape_vec.end(), static_cast<dsize_t>(1), std::multiplies<dsize_t>());
  // Two int64 entries (offset, len) are consumed per edge below; reject a tensor that is too short
  // instead of reading past its end.
  CHECK_FAIL_RETURN_UNEXPECTED(memory_tensor->Size() >= 2 * size,
                               "The shared memory address tensor is smaller than expected");
  shape = shape.PrependDim(size);
  std::shared_ptr<Tensor> fea_tensor;
  RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, default_feature->type(), &fea_tensor));
  dsize_t index = 0;
  auto fea_addr_itr = memory_tensor->begin<int64_t>();
  for (auto edge_itr = edges->begin<EdgeIdType>(); edge_itr != edges->end<EdgeIdType>(); ++edge_itr) {
    int64_t offset = *fea_addr_itr;
    fea_addr_itr++;
    int64_t len = *fea_addr_itr;
    fea_addr_itr++;
    if (offset < 0 || len <= 0) {
      // No valid address for this edge: use the default feature for this row.
      RETURN_IF_NOT_OK(fea_tensor->InsertTensor({index}, default_feature));
    } else {
      // Copy len bytes from shared memory at offset straight into row `index` of the output.
      uchar *start_addr_of_index = nullptr;
      TensorShape remaining({-1});
      RETURN_IF_NOT_OK(fea_tensor->StartAddrOfIndex({index}, &start_addr_of_index, &remaining));
      RETURN_IF_NOT_OK(graph_shared_memory_->GetData(start_addr_of_index, len, offset, len));
    }
    index++;
  }
  // Restore the shape of `edges` in front of the per-feature dimensions.
  TensorShape reshape(edges->shape());
  for (auto s : default_feature->shape().AsVector()) {
    reshape = reshape.AppendDim(s);
  }
  RETURN_IF_NOT_OK(fea_tensor->Reshape(reshape));
  fea_tensor->Squeeze();
  *out = std::move(fea_tensor);
  return Status::OK();
}
// Look up the default node feature registered for a feature type.
// @param FeatureType feature_type - Feature type to look up
// @param std::shared_ptr<Tensor> *out_feature - Set to the default feature tensor when the type is known
// @return Status - The error code return; an unknown feature type is reported as an error
Status GraphDataClient::GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Tensor> *out_feature) {
  const auto found = default_node_feature_map_.find(feature_type);
  if (found == default_node_feature_map_.end()) {
    RETURN_STATUS_UNEXPECTED("Invalid feature type:" + std::to_string(feature_type));
  }
  *out_feature = found->second;
  return Status::OK();
}
// Look up the default edge feature registered for a feature type.
// @param FeatureType feature_type - Feature type to look up
// @param std::shared_ptr<Tensor> *out_feature - Set to the default feature tensor when the type is known
// @return Status - The error code return; an unknown feature type is reported as an error
Status GraphDataClient::GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Tensor> *out_feature) {
  const auto found = default_edge_feature_map_.find(feature_type);
  if (found == default_edge_feature_map_.end()) {
    RETURN_STATUS_UNEXPECTED("Invalid feature type:" + std::to_string(feature_type));
  }
  *out_feature = found->second;
  return Status::OK();
}
// Register this client (identified by pid_) with the server and cache what the reply carries:
// the data schema, the shared memory key and size, and the default node/edge features.
// @return Status - The error code return
Status GraphDataClient::RegisterToServer() {
  RETURN_IF_NOT_OK(CheckPid());
  void *tag;
  bool ok;
  grpc::Status status;
  grpc::ClientContext ctx;
  grpc::CompletionQueue cq;
  GnnClientRegisterRequestPb request;
  GnnClientRegisterResponsePb response;
  request.set_pid(static_cast<google::protobuf::int32>(pid_));
  // One minute timeout
  auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60);
  ctx.set_deadline(deadline);
  std::unique_ptr<grpc::ClientAsyncResponseReader<GnnClientRegisterResponsePb>> rpc(
    stub_->PrepareAsyncClientRegister(&ctx, request, &cq));
  rpc->StartCall();
  // The address of the response object is used as the completion tag.
  rpc->Finish(&response, &status, &response);
  {
    // The GIL is released only for the duration of the blocking wait on the completion queue.
    py::gil_scoped_release gil_release;
    auto success = cq.Next(&tag, &ok);
    CHECK_FAIL_RETURN_UNEXPECTED(success, "Expect successful");
    CHECK_FAIL_RETURN_UNEXPECTED(tag == &response, "Expect the same tag");
    CHECK_FAIL_RETURN_UNEXPECTED(ok, "Expect successful");
  }
  if (!status.ok()) {
    auto error_code = status.error_code();
    RETURN_STATUS_UNEXPECTED(status.error_message() + ". GRPC Code " + std::to_string(error_code));
  }
  if (response.error_msg() != "Success") {
    RETURN_STATUS_UNEXPECTED(response.error_msg());
  }
  // The server has accepted the registration at this point, so record it before any local
  // processing that may still fail.
  registered_ = true;
  // Parse with exceptions disabled so that a malformed schema is reported through Status
  // instead of throwing out of this function.
  mindrecord::json schema = mindrecord::json::parse(response.data_schema(), nullptr, false);
  if (schema.is_discarded()) {
    RETURN_STATUS_UNEXPECTED("Failed to parse data_schema received from server: " + response.data_schema());
  }
  data_schema_ = std::move(schema);
  shared_memory_key_ = static_cast<key_t>(response.shared_memory_key());
  shared_memory_size_ = response.shared_memory_size();
  MS_LOG(INFO) << "Register success, recv data_schema:" << response.data_schema();
  // Iterate by const reference: copying each protobuf message per iteration is unnecessary.
  for (const auto &feature_info : response.default_node_feature()) {
    std::shared_ptr<Tensor> tensor;
    RETURN_IF_NOT_OK(PbToTensor(&feature_info.feature(), &tensor));
    default_node_feature_map_[feature_info.type()] = tensor;
  }
  for (const auto &feature_info : response.default_edge_feature()) {
    std::shared_ptr<Tensor> tensor;
    RETURN_IF_NOT_OK(PbToTensor(&feature_info.feature(), &tensor));
    default_edge_feature_map_[feature_info.type()] = tensor;
  }
  return Status::OK();
}
// Tell the server that this client (identified by pid_) is going away; clears registered_ on success.
// @return Status - The error code return
Status GraphDataClient::UnRegisterToServer() {
  RETURN_IF_NOT_OK(CheckPid());
  MS_LOG(INFO) << "Graph data client send unregistered to server ";
  void *finished_tag = nullptr;
  bool finished_ok = false;
  // Declaration order of the gRPC objects is kept on purpose: it fixes their destruction order.
  grpc::Status rpc_status;
  grpc::ClientContext client_ctx;
  grpc::CompletionQueue queue;
  GnnClientUnRegisterRequestPb unregister_request;
  GnnClientUnRegisterResponsePb unregister_response;
  unregister_request.set_pid(static_cast<google::protobuf::int32>(pid_));
  // One minute timeout
  client_ctx.set_deadline(std::chrono::system_clock::now() + std::chrono::seconds(60));
  auto call = stub_->PrepareAsyncClientUnRegister(&client_ctx, unregister_request, &queue);
  call->StartCall();
  // The address of the response object is used as the completion tag.
  call->Finish(&unregister_response, &rpc_status, &unregister_response);
  {
    // The GIL is released only for the duration of the blocking wait on the completion queue.
    py::gil_scoped_release gil_release;
    const bool got_event = queue.Next(&finished_tag, &finished_ok);
    CHECK_FAIL_RETURN_UNEXPECTED(got_event, "Expect successful");
    CHECK_FAIL_RETURN_UNEXPECTED(finished_tag == &unregister_response, "Expect the same tag");
    CHECK_FAIL_RETURN_UNEXPECTED(finished_ok, "Expect successful");
  }
  // Transport-level failure reported by gRPC.
  if (!rpc_status.ok()) {
    auto error_code = rpc_status.error_code();
    RETURN_STATUS_UNEXPECTED(rpc_status.error_message() + ". GRPC Code " + std::to_string(error_code));
  }
  // Application-level failure reported by the server inside the reply.
  if (unregister_response.error_msg() != "Success") {
    RETURN_STATUS_UNEXPECTED(unregister_response.error_msg());
  }
  MS_LOG(INFO) << "Unregister success.";
  registered_ = false;
  return Status::OK();
}
// Create the shared memory accessor and the feature parser from shared_memory_size_,
// shared_memory_key_ and data_schema_.
// @return Status - The error code return
Status GraphDataClient::InitFeatureParser() {
  // get shared memory
  graph_shared_memory_ = std::make_unique<GraphSharedMemory>(shared_memory_size_, shared_memory_key_);
  Status shm_status = graph_shared_memory_->GetSharedMemory();
  if (!shm_status.IsOk()) {
    return shm_status;
  }
  // build feature parser
  graph_feature_parser_ = std::make_unique<GraphFeatureParser>(ShardColumn(data_schema_));
  return Status::OK();
}
#endif
} // namespace gnn
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,185 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_CLIENT_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_CLIENT_H_
#include <algorithm>
#include <memory>
#include <string>
#include <map>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <utility>
#if !defined(_WIN32) && !defined(_WIN64)
#include "proto/gnn_graph_data.grpc.pb.h"
#include "proto/gnn_graph_data.pb.h"
#endif
#include "minddata/dataset/engine/gnn/graph_data.h"
#include "minddata/dataset/engine/gnn/graph_feature_parser.h"
#if !defined(_WIN32) && !defined(_WIN64)
#include "minddata/dataset/engine/gnn/graph_shared_memory.h"
#endif
#include "minddata/mindrecord/include/common/shard_utils.h"
#include "minddata/mindrecord/include/shard_column.h"
namespace mindspore {
namespace dataset {
namespace gnn {
// GraphData implementation whose queries are sent through a GnnGraphData gRPC stub.
// The gRPC stub, shared memory access and the private helpers are only compiled on non-Windows platforms.
class GraphDataClient : public GraphData {
public:
// Constructor
// @param std::string dataset_file - dataset file name (how the client uses it is not visible here - TODO confirm)
// @param std::string hostname - presumably the host of the graph data server - TODO confirm against the .cc
// @param int32_t port - presumably the port of the graph data server - TODO confirm against the .cc
GraphDataClient(const std::string &dataset_file, const std::string &hostname, int32_t port);
~GraphDataClient();
// Overrides of the GraphData lifecycle hooks; their behaviour is defined in the .cc file.
Status Init() override;
Status Stop() override;
// Get all nodes from the graph.
// @param NodeType node_type - type of node
// @param std::shared_ptr<Tensor> *out - Returned nodes id
// @return Status - The error code return
Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) override;
// Get all edges from the graph.
// @param EdgeType edge_type - type of edge
// @param std::shared_ptr<Tensor> *out - Returned edge ids
// @return Status - The error code return
Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) override;
// Get the node id from the edge.
// @param std::vector<EdgeIdType> edge_list - List of edges
// @param std::shared_ptr<Tensor> *out - Returned node ids
// @return Status - The error code return
Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override;
// Get all neighbors of the given nodes.
// @param std::vector<NodeIdType> node_list - List of nodes
// @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is
// different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors
// is not enough, fill in tensor as -1.
// @return Status - The error code return
Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
std::shared_ptr<Tensor> *out) override;
// Get sampled neighbors.
// @param std::vector<NodeIdType> node_list - List of nodes
// @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
// @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
// @return Status - The error code return
Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override;
// Get negative sampled neighbors.
// @param std::vector<NodeIdType> node_list - List of nodes
// @param NodeIdType samples_num - Number of neighbors sampled
// @param NodeType neg_neighbor_type - The type of negative neighbor.
// @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id.
// @return Status - The error code return
Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) override;
// Node2vec random walk.
// @param std::vector<NodeIdType> node_list - List of nodes
// @param std::vector<NodeType> meta_path - node type of each step
// @param float step_home_param - return hyper parameter in node2vec algorithm
// @param float step_away_param - inout hyper parameter in node2vec algorithm
// @param NodeIdType default_node - default node id
// @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path
// @return Status - The error code return
Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
float step_home_param, float step_away_param, NodeIdType default_node,
std::shared_ptr<Tensor> *out) override;
// Get the feature of a node
// @param std::shared_ptr<Tensor> nodes - List of nodes
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
// does not exist.
// @param TensorRow *out - Returned features
// @return Status - The error code return
Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
TensorRow *out) override;
// Get the feature of an edge
// @param std::shared_ptr<Tensor> edges - List of edges
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
// does not exist.
// @param TensorRow *out - Returned features
// @return Status - The error code return
Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types,
TensorRow *out) override;
// Return meta information to python layer
Status GraphInfo(py::dict *out) override;
private:
#if !defined(_WIN32) && !defined(_WIN64)
// Build the node feature tensor from a tensor of (offset, len) pairs addressing shared memory.
Status ParseNodeFeatureFromMemory(const std::shared_ptr<Tensor> &nodes, FeatureType feature_type,
const std::shared_ptr<Tensor> &memory_tensor, std::shared_ptr<Tensor> *out);
// Build the edge feature tensor from a tensor of (offset, len) pairs addressing shared memory.
Status ParseEdgeFeatureFromMemory(const std::shared_ptr<Tensor> &edges, FeatureType feature_type,
const std::shared_ptr<Tensor> &memory_tensor, std::shared_ptr<Tensor> *out);
// Look up the default feature tensor of a node / edge feature type; unknown types are an error.
Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Tensor> *out_feature);
Status GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Tensor> *out_feature);
// Send one GetGraphData request and wait for the reply (one minute deadline).
Status GetGraphData(const GnnGraphDataRequestPb &request, GnnGraphDataResponsePb *response);
// Same as GetGraphData, and additionally converts the single tensor of the reply into *out.
Status GetGraphDataTensor(const GnnGraphDataRequestPb &request, GnnGraphDataResponsePb *response,
std::shared_ptr<Tensor> *out);
// Register / unregister this client (identified by pid_) with the server.
Status RegisterToServer();
Status UnRegisterToServer();
// Create graph_shared_memory_ and graph_feature_parser_ from the values received at registration.
Status InitFeatureParser();
// Fail when called from a process other than the one whose id is stored in pid_.
Status CheckPid() {
CHECK_FAIL_RETURN_UNEXPECTED(pid_ == getpid(),
"Multi-process mode is not supported, please change to use multi-thread");
return Status::OK();
}
#endif
std::string dataset_file_;
std::string host_;
int32_t port_;
// Process id compared against getpid() by CheckPid and sent to the server on (un)registration.
int32_t pid_;
// Data schema parsed from the registration reply.
mindrecord::json data_schema_;
#if !defined(_WIN32) && !defined(_WIN64)
// gRPC stub used for every request to the server.
std::unique_ptr<GnnGraphData::Stub> stub_;
// Shared memory key and size received in the registration reply.
key_t shared_memory_key_;
int64_t shared_memory_size_;
std::unique_ptr<GraphFeatureParser> graph_feature_parser_;
std::unique_ptr<GraphSharedMemory> graph_shared_memory_;
// Default feature tensors per feature type, filled from the registration reply.
std::unordered_map<FeatureType, std::shared_ptr<Tensor>> default_node_feature_map_;
std::unordered_map<FeatureType, std::shared_ptr<Tensor>> default_edge_feature_map_;
#endif
// True between a successful RegisterToServer and a successful UnRegisterToServer.
bool registered_;
};
} // namespace gnn
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_CLIENT_H_

View File

@ -13,7 +13,7 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "minddata/dataset/engine/gnn/graph.h" #include "minddata/dataset/engine/gnn/graph_data_impl.h"
#include <algorithm> #include <algorithm>
#include <functional> #include <functional>
@ -22,19 +22,25 @@
#include <utility> #include <utility>
#include "minddata/dataset/core/tensor_shape.h" #include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/engine/gnn/graph_loader.h"
#include "minddata/dataset/util/random.h" #include "minddata/dataset/util/random.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace gnn { namespace gnn {
Graph::Graph(std::string dataset_file, int32_t num_workers) GraphDataImpl::GraphDataImpl(std::string dataset_file, int32_t num_workers, bool server_mode)
: dataset_file_(dataset_file), num_workers_(num_workers), rnd_(GetRandomDevice()), random_walk_(this) { : dataset_file_(dataset_file),
num_workers_(num_workers),
rnd_(GetRandomDevice()),
random_walk_(this),
server_mode_(server_mode) {
rnd_.seed(GetSeed()); rnd_.seed(GetSeed());
MS_LOG(INFO) << "num_workers:" << num_workers; MS_LOG(INFO) << "num_workers:" << num_workers;
} }
Status Graph::GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) { GraphDataImpl::~GraphDataImpl() {}
Status GraphDataImpl::GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) {
auto itr = node_type_map_.find(node_type); auto itr = node_type_map_.find(node_type);
if (itr == node_type_map_.end()) { if (itr == node_type_map_.end()) {
std::string err_msg = "Invalid node type:" + std::to_string(node_type); std::string err_msg = "Invalid node type:" + std::to_string(node_type);
@ -46,8 +52,8 @@ Status Graph::GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) {
} }
template <typename T> template <typename T>
Status Graph::CreateTensorByVector(const std::vector<std::vector<T>> &data, DataType type, Status GraphDataImpl::CreateTensorByVector(const std::vector<std::vector<T>> &data, DataType type,
std::shared_ptr<Tensor> *out) { std::shared_ptr<Tensor> *out) {
if (!type.IsCompatible<T>()) { if (!type.IsCompatible<T>()) {
RETURN_STATUS_UNEXPECTED("Data type not compatible"); RETURN_STATUS_UNEXPECTED("Data type not compatible");
} }
@ -72,7 +78,7 @@ Status Graph::CreateTensorByVector(const std::vector<std::vector<T>> &data, Data
} }
template <typename T> template <typename T>
Status Graph::ComplementVector(std::vector<std::vector<T>> *data, size_t max_size, T default_value) { Status GraphDataImpl::ComplementVector(std::vector<std::vector<T>> *data, size_t max_size, T default_value) {
if (!data || data->empty()) { if (!data || data->empty()) {
RETURN_STATUS_UNEXPECTED("Input data is empty"); RETURN_STATUS_UNEXPECTED("Input data is empty");
} }
@ -89,7 +95,7 @@ Status Graph::ComplementVector(std::vector<std::vector<T>> *data, size_t max_siz
return Status::OK(); return Status::OK();
} }
Status Graph::GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) { Status GraphDataImpl::GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) {
auto itr = edge_type_map_.find(edge_type); auto itr = edge_type_map_.find(edge_type);
if (itr == edge_type_map_.end()) { if (itr == edge_type_map_.end()) {
std::string err_msg = "Invalid edge type:" + std::to_string(edge_type); std::string err_msg = "Invalid edge type:" + std::to_string(edge_type);
@ -100,7 +106,7 @@ Status Graph::GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) {
return Status::OK(); return Status::OK();
} }
Status Graph::GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) { Status GraphDataImpl::GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) {
if (edge_list.empty()) { if (edge_list.empty()) {
RETURN_STATUS_UNEXPECTED("Input edge_list is empty"); RETURN_STATUS_UNEXPECTED("Input edge_list is empty");
} }
@ -122,8 +128,8 @@ Status Graph::GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::s
return Status::OK(); return Status::OK();
} }
Status Graph::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, Status GraphDataImpl::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
std::shared_ptr<Tensor> *out) { std::shared_ptr<Tensor> *out) {
CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty.");
RETURN_IF_NOT_OK(CheckNeighborType(neighbor_type)); RETURN_IF_NOT_OK(CheckNeighborType(neighbor_type));
@ -143,7 +149,7 @@ Status Graph::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType
return Status::OK(); return Status::OK();
} }
Status Graph::CheckSamplesNum(NodeIdType samples_num) { Status GraphDataImpl::CheckSamplesNum(NodeIdType samples_num) {
NodeIdType all_nodes_number = NodeIdType all_nodes_number =
std::accumulate(node_type_map_.begin(), node_type_map_.end(), 0, std::accumulate(node_type_map_.begin(), node_type_map_.end(), 0,
[](NodeIdType t1, const auto &t2) -> NodeIdType { return t1 + t2.second.size(); }); [](NodeIdType t1, const auto &t2) -> NodeIdType { return t1 + t2.second.size(); });
@ -155,7 +161,7 @@ Status Graph::CheckSamplesNum(NodeIdType samples_num) {
return Status::OK(); return Status::OK();
} }
Status Graph::CheckNeighborType(NodeType neighbor_type) { Status GraphDataImpl::CheckNeighborType(NodeType neighbor_type) {
if (node_type_map_.find(neighbor_type) == node_type_map_.end()) { if (node_type_map_.find(neighbor_type) == node_type_map_.end()) {
std::string err_msg = "Invalid neighbor type:" + std::to_string(neighbor_type); std::string err_msg = "Invalid neighbor type:" + std::to_string(neighbor_type);
RETURN_STATUS_UNEXPECTED(err_msg); RETURN_STATUS_UNEXPECTED(err_msg);
@ -163,9 +169,9 @@ Status Graph::CheckNeighborType(NodeType neighbor_type) {
return Status::OK(); return Status::OK();
} }
Status Graph::GetSampledNeighbors(const std::vector<NodeIdType> &node_list, Status GraphDataImpl::GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
const std::vector<NodeIdType> &neighbor_nums, const std::vector<NodeIdType> &neighbor_nums,
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) { const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) {
CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty.");
CHECK_FAIL_RETURN_UNEXPECTED(neighbor_nums.size() == neighbor_types.size(), CHECK_FAIL_RETURN_UNEXPECTED(neighbor_nums.size() == neighbor_types.size(),
"The sizes of neighbor_nums and neighbor_types are inconsistent."); "The sizes of neighbor_nums and neighbor_types are inconsistent.");
@ -205,8 +211,9 @@ Status Graph::GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
return Status::OK(); return Status::OK();
} }
Status Graph::NegativeSample(const std::vector<NodeIdType> &data, const std::unordered_set<NodeIdType> &exclude_data, Status GraphDataImpl::NegativeSample(const std::vector<NodeIdType> &data,
int32_t samples_num, std::vector<NodeIdType> *out_samples) { const std::unordered_set<NodeIdType> &exclude_data, int32_t samples_num,
std::vector<NodeIdType> *out_samples) {
CHECK_FAIL_RETURN_UNEXPECTED(!data.empty(), "Input data is empty."); CHECK_FAIL_RETURN_UNEXPECTED(!data.empty(), "Input data is empty.");
std::vector<NodeIdType> shuffled_id(data.size()); std::vector<NodeIdType> shuffled_id(data.size());
std::iota(shuffled_id.begin(), shuffled_id.end(), 0); std::iota(shuffled_id.begin(), shuffled_id.end(), 0);
@ -223,8 +230,8 @@ Status Graph::NegativeSample(const std::vector<NodeIdType> &data, const std::uno
return Status::OK(); return Status::OK();
} }
Status Graph::GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, Status GraphDataImpl::GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) { NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) {
CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty.");
RETURN_IF_NOT_OK(CheckSamplesNum(samples_num)); RETURN_IF_NOT_OK(CheckSamplesNum(samples_num));
RETURN_IF_NOT_OK(CheckNeighborType(neg_neighbor_type)); RETURN_IF_NOT_OK(CheckNeighborType(neg_neighbor_type));
@ -260,9 +267,9 @@ Status Graph::GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, N
return Status::OK(); return Status::OK();
} }
Status Graph::RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, Status GraphDataImpl::RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
float step_home_param, float step_away_param, NodeIdType default_node, float step_home_param, float step_away_param, NodeIdType default_node,
std::shared_ptr<Tensor> *out) { std::shared_ptr<Tensor> *out) {
RETURN_IF_NOT_OK(random_walk_.Build(node_list, meta_path, step_home_param, step_away_param, default_node)); RETURN_IF_NOT_OK(random_walk_.Build(node_list, meta_path, step_home_param, step_away_param, default_node));
std::vector<std::vector<NodeIdType>> walks; std::vector<std::vector<NodeIdType>> walks;
RETURN_IF_NOT_OK(random_walk_.SimulateWalk(&walks)); RETURN_IF_NOT_OK(random_walk_.SimulateWalk(&walks));
@ -270,7 +277,7 @@ Status Graph::RandomWalk(const std::vector<NodeIdType> &node_list, const std::ve
return Status::OK(); return Status::OK();
} }
Status Graph::GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) { Status GraphDataImpl::GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) {
auto itr = default_node_feature_map_.find(feature_type); auto itr = default_node_feature_map_.find(feature_type);
if (itr == default_node_feature_map_.end()) { if (itr == default_node_feature_map_.end()) {
std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); std::string err_msg = "Invalid feature type:" + std::to_string(feature_type);
@ -281,7 +288,7 @@ Status Graph::GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Fe
return Status::OK(); return Status::OK();
} }
Status Graph::GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) { Status GraphDataImpl::GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) {
auto itr = default_edge_feature_map_.find(feature_type); auto itr = default_edge_feature_map_.find(feature_type);
if (itr == default_edge_feature_map_.end()) { if (itr == default_edge_feature_map_.end()) {
std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); std::string err_msg = "Invalid feature type:" + std::to_string(feature_type);
@ -292,8 +299,8 @@ Status Graph::GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Fe
return Status::OK(); return Status::OK();
} }
Status Graph::GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types, Status GraphDataImpl::GetNodeFeature(const std::shared_ptr<Tensor> &nodes,
TensorRow *out) { const std::vector<FeatureType> &feature_types, TensorRow *out) {
if (!nodes || nodes->Size() == 0) { if (!nodes || nodes->Size() == 0) {
RETURN_STATUS_UNEXPECTED("Input nodes is empty"); RETURN_STATUS_UNEXPECTED("Input nodes is empty");
} }
@ -339,8 +346,49 @@ Status Graph::GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::ve
return Status::OK(); return Status::OK();
} }
Status Graph::GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types, Status GraphDataImpl::GetNodeFeatureSharedMemory(const std::shared_ptr<Tensor> &nodes, FeatureType type,
TensorRow *out) { std::shared_ptr<Tensor> *out) {
if (!nodes || nodes->Size() == 0) {
RETURN_STATUS_UNEXPECTED("Input nodes is empty");
}
TensorShape shape = nodes->shape().AppendDim(2);
std::shared_ptr<Tensor> fea_tensor;
RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, DataType(DataType::DE_INT64), &fea_tensor));
auto out_fea_itr = fea_tensor->begin<int64_t>();
for (auto node_itr = nodes->begin<NodeIdType>(); node_itr != nodes->end<NodeIdType>(); ++node_itr) {
if (*node_itr == kDefaultNodeId) {
*out_fea_itr = -1;
++out_fea_itr;
*out_fea_itr = -1;
++out_fea_itr;
} else {
std::shared_ptr<Node> node;
RETURN_IF_NOT_OK(GetNodeByNodeId(*node_itr, &node));
std::shared_ptr<Feature> feature;
if (!node->GetFeatures(type, &feature).IsOk()) {
*out_fea_itr = -1;
++out_fea_itr;
*out_fea_itr = -1;
++out_fea_itr;
} else {
for (auto fea_itr = feature->Value()->begin<int64_t>(); fea_itr != feature->Value()->end<int64_t>();
++fea_itr) {
*out_fea_itr = *fea_itr;
++out_fea_itr;
}
}
}
}
fea_tensor->Squeeze();
*out = std::move(fea_tensor);
return Status::OK();
}
Status GraphDataImpl::GetEdgeFeature(const std::shared_ptr<Tensor> &edges,
const std::vector<FeatureType> &feature_types, TensorRow *out) {
if (!edges || edges->Size() == 0) { if (!edges || edges->Size() == 0) {
RETURN_STATUS_UNEXPECTED("Input edges is empty"); RETURN_STATUS_UNEXPECTED("Input edges is empty");
} }
@ -382,12 +430,45 @@ Status Graph::GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::ve
return Status::OK(); return Status::OK();
} }
Status Graph::Init() { Status GraphDataImpl::GetEdgeFeatureSharedMemory(const std::shared_ptr<Tensor> &edges, FeatureType type,
std::shared_ptr<Tensor> *out) {
if (!edges || edges->Size() == 0) {
RETURN_STATUS_UNEXPECTED("Input edges is empty");
}
TensorShape shape = edges->shape().AppendDim(2);
std::shared_ptr<Tensor> fea_tensor;
RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, DataType(DataType::DE_INT64), &fea_tensor));
auto out_fea_itr = fea_tensor->begin<int64_t>();
for (auto edge_itr = edges->begin<EdgeIdType>(); edge_itr != edges->end<EdgeIdType>(); ++edge_itr) {
std::shared_ptr<Edge> edge;
RETURN_IF_NOT_OK(GetEdgeByEdgeId(*edge_itr, &edge));
std::shared_ptr<Feature> feature;
if (!edge->GetFeatures(type, &feature).IsOk()) {
*out_fea_itr = -1;
++out_fea_itr;
*out_fea_itr = -1;
++out_fea_itr;
} else {
for (auto fea_itr = feature->Value()->begin<int64_t>(); fea_itr != feature->Value()->end<int64_t>(); ++fea_itr) {
*out_fea_itr = *fea_itr;
++out_fea_itr;
}
}
}
fea_tensor->Squeeze();
*out = std::move(fea_tensor);
return Status::OK();
}
Status GraphDataImpl::Init() {
RETURN_IF_NOT_OK(LoadNodeAndEdge()); RETURN_IF_NOT_OK(LoadNodeAndEdge());
return Status::OK(); return Status::OK();
} }
Status Graph::GetMetaInfo(MetaInfo *meta_info) { Status GraphDataImpl::GetMetaInfo(MetaInfo *meta_info) {
meta_info->node_type.resize(node_type_map_.size()); meta_info->node_type.resize(node_type_map_.size());
std::transform(node_type_map_.begin(), node_type_map_.end(), meta_info->node_type.begin(), std::transform(node_type_map_.begin(), node_type_map_.end(), meta_info->node_type.begin(),
[](auto itr) { return itr.first; }); [](auto itr) { return itr.first; });
@ -427,7 +508,7 @@ Status Graph::GetMetaInfo(MetaInfo *meta_info) {
} }
#ifdef ENABLE_PYTHON #ifdef ENABLE_PYTHON
Status Graph::GraphInfo(py::dict *out) { Status GraphDataImpl::GraphInfo(py::dict *out) {
MetaInfo meta_info; MetaInfo meta_info;
RETURN_IF_NOT_OK(GetMetaInfo(&meta_info)); RETURN_IF_NOT_OK(GetMetaInfo(&meta_info));
(*out)["node_type"] = py::cast(meta_info.node_type); (*out)["node_type"] = py::cast(meta_info.node_type);
@ -440,18 +521,16 @@ Status Graph::GraphInfo(py::dict *out) {
} }
#endif #endif
Status Graph::LoadNodeAndEdge() { Status GraphDataImpl::LoadNodeAndEdge() {
GraphLoader gl(dataset_file_, num_workers_); GraphLoader gl(this, dataset_file_, num_workers_, server_mode_);
// ask graph_loader to load everything into memory // ask graph_loader to load everything into memory
RETURN_IF_NOT_OK(gl.InitAndLoad()); RETURN_IF_NOT_OK(gl.InitAndLoad());
// get all maps // get all maps
RETURN_IF_NOT_OK(gl.GetNodesAndEdges(&node_id_map_, &edge_id_map_, &node_type_map_, &edge_type_map_, RETURN_IF_NOT_OK(gl.GetNodesAndEdges());
&node_feature_map_, &edge_feature_map_, &default_node_feature_map_,
&default_edge_feature_map_));
return Status::OK(); return Status::OK();
} }
Status Graph::GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node) { Status GraphDataImpl::GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node) {
auto itr = node_id_map_.find(id); auto itr = node_id_map_.find(id);
if (itr == node_id_map_.end()) { if (itr == node_id_map_.end()) {
std::string err_msg = "Invalid node id:" + std::to_string(id); std::string err_msg = "Invalid node id:" + std::to_string(id);
@ -462,7 +541,7 @@ Status Graph::GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node) {
return Status::OK(); return Status::OK();
} }
Status Graph::GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr<Edge> *edge) { Status GraphDataImpl::GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr<Edge> *edge) {
auto itr = edge_id_map_.find(id); auto itr = edge_id_map_.find(id);
if (itr == edge_id_map_.end()) { if (itr == edge_id_map_.end()) {
std::string err_msg = "Invalid edge id:" + std::to_string(id); std::string err_msg = "Invalid edge id:" + std::to_string(id);
@ -473,12 +552,13 @@ Status Graph::GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr<Edge> *edge) {
return Status::OK(); return Status::OK();
} }
Graph::RandomWalkBase::RandomWalkBase(Graph *graph) GraphDataImpl::RandomWalkBase::RandomWalkBase(GraphDataImpl *graph)
: graph_(graph), step_home_param_(1.0), step_away_param_(1.0), default_node_(-1), num_walks_(1), num_workers_(1) {} : graph_(graph), step_home_param_(1.0), step_away_param_(1.0), default_node_(-1), num_walks_(1), num_workers_(1) {}
Status Graph::RandomWalkBase::Build(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, Status GraphDataImpl::RandomWalkBase::Build(const std::vector<NodeIdType> &node_list,
float step_home_param, float step_away_param, const NodeIdType default_node, const std::vector<NodeType> &meta_path, float step_home_param,
int32_t num_walks, int32_t num_workers) { float step_away_param, const NodeIdType default_node, int32_t num_walks,
int32_t num_workers) {
CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty.");
node_list_ = node_list; node_list_ = node_list;
if (meta_path.empty() || meta_path.size() > kMaxNumWalks) { if (meta_path.empty() || meta_path.size() > kMaxNumWalks) {
@ -516,7 +596,7 @@ Status Graph::RandomWalkBase::Build(const std::vector<NodeIdType> &node_list, co
return Status::OK(); return Status::OK();
} }
Status Graph::RandomWalkBase::Node2vecWalk(const NodeIdType &start_node, std::vector<NodeIdType> *walk_path) { Status GraphDataImpl::RandomWalkBase::Node2vecWalk(const NodeIdType &start_node, std::vector<NodeIdType> *walk_path) {
// Simulate a random walk starting from start node. // Simulate a random walk starting from start node.
auto walk = std::vector<NodeIdType>(1, start_node); // walk is an vector auto walk = std::vector<NodeIdType>(1, start_node); // walk is an vector
// walk simulate // walk simulate
@ -556,8 +636,8 @@ Status Graph::RandomWalkBase::Node2vecWalk(const NodeIdType &start_node, std::ve
return Status::OK(); return Status::OK();
} }
Status Graph::RandomWalkBase::SimulateWalk(std::vector<std::vector<NodeIdType>> *walks) { Status GraphDataImpl::RandomWalkBase::SimulateWalk(std::vector<std::vector<NodeIdType>> *walks) {
for (int32_t i = 0; i < num_walks_; i++) { for (int32_t i = 0; i < num_walks_; ++i) {
for (const auto &node : node_list_) { for (const auto &node : node_list_) {
std::vector<NodeIdType> walk; std::vector<NodeIdType> walk;
RETURN_IF_NOT_OK(Node2vecWalk(node, &walk)); RETURN_IF_NOT_OK(Node2vecWalk(node, &walk));
@ -567,8 +647,8 @@ Status Graph::RandomWalkBase::SimulateWalk(std::vector<std::vector<NodeIdType>>
return Status::OK(); return Status::OK();
} }
Status Graph::RandomWalkBase::GetNodeProbability(const NodeIdType &node_id, const NodeType &node_type, Status GraphDataImpl::RandomWalkBase::GetNodeProbability(const NodeIdType &node_id, const NodeType &node_type,
std::shared_ptr<StochasticIndex> *node_probability) { std::shared_ptr<StochasticIndex> *node_probability) {
// Generate alias nodes // Generate alias nodes
std::shared_ptr<Node> node; std::shared_ptr<Node> node;
graph_->GetNodeByNodeId(node_id, &node); graph_->GetNodeByNodeId(node_id, &node);
@ -581,8 +661,9 @@ Status Graph::RandomWalkBase::GetNodeProbability(const NodeIdType &node_id, cons
return Status::OK(); return Status::OK();
} }
Status Graph::RandomWalkBase::GetEdgeProbability(const NodeIdType &src, const NodeIdType &dst, uint32_t meta_path_index, Status GraphDataImpl::RandomWalkBase::GetEdgeProbability(const NodeIdType &src, const NodeIdType &dst,
std::shared_ptr<StochasticIndex> *edge_probability) { uint32_t meta_path_index,
std::shared_ptr<StochasticIndex> *edge_probability) {
// Get the alias edge setup lists for a given edge. // Get the alias edge setup lists for a given edge.
std::shared_ptr<Node> src_node; std::shared_ptr<Node> src_node;
graph_->GetNodeByNodeId(src, &src_node); graph_->GetNodeByNodeId(src, &src_node);
@ -616,7 +697,7 @@ Status Graph::RandomWalkBase::GetEdgeProbability(const NodeIdType &src, const No
return Status::OK(); return Status::OK();
} }
StochasticIndex Graph::RandomWalkBase::GenerateProbability(const std::vector<float> &probability) { StochasticIndex GraphDataImpl::RandomWalkBase::GenerateProbability(const std::vector<float> &probability) {
uint32_t K = probability.size(); uint32_t K = probability.size();
std::vector<int32_t> switch_to_large_index(K, 0); std::vector<int32_t> switch_to_large_index(K, 0);
std::vector<float> weight(K, .0); std::vector<float> weight(K, .0);
@ -644,7 +725,7 @@ StochasticIndex Graph::RandomWalkBase::GenerateProbability(const std::vector<flo
return StochasticIndex(switch_to_large_index, weight); return StochasticIndex(switch_to_large_index, weight);
} }
uint32_t Graph::RandomWalkBase::WalkToNextNode(const StochasticIndex &stochastic_index) { uint32_t GraphDataImpl::RandomWalkBase::WalkToNextNode(const StochasticIndex &stochastic_index) {
auto switch_to_large_index = stochastic_index.first; auto switch_to_large_index = stochastic_index.first;
auto weight = stochastic_index.second; auto weight = stochastic_index.second;
const uint32_t size_of_index = switch_to_large_index.size(); const uint32_t size_of_index = switch_to_large_index.size();
@ -662,7 +743,7 @@ uint32_t Graph::RandomWalkBase::WalkToNextNode(const StochasticIndex &stochastic
} }
template <typename T> template <typename T>
std::vector<float> Graph::RandomWalkBase::Normalize(const std::vector<T> &non_normalized_probability) { std::vector<float> GraphDataImpl::RandomWalkBase::Normalize(const std::vector<T> &non_normalized_probability) {
float sum_probability = float sum_probability =
1.0 * std::accumulate(non_normalized_probability.begin(), non_normalized_probability.end(), 0); 1.0 * std::accumulate(non_normalized_probability.begin(), non_normalized_probability.end(), 0);
if (sum_probability < kGnnEpsilon) { if (sum_probability < kGnnEpsilon) {

View File

@ -13,8 +13,8 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_H_ #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_IMPL_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_H_ #define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_IMPL_H_
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
@ -25,13 +25,11 @@
#include <vector> #include <vector>
#include <utility> #include <utility>
#include "minddata/dataset/core/tensor.h" #include "minddata/dataset/engine/gnn/graph_data.h"
#include "minddata/dataset/core/tensor_row.h" #if !defined(_WIN32) && !defined(_WIN64)
#include "minddata/dataset/engine/gnn/graph_loader.h" #include "minddata/dataset/engine/gnn/graph_shared_memory.h"
#include "minddata/dataset/engine/gnn/feature.h" #endif
#include "minddata/dataset/engine/gnn/node.h" #include "minddata/mindrecord/include/common/shard_utils.h"
#include "minddata/dataset/engine/gnn/edge.h"
#include "minddata/dataset/util/status.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
@ -41,41 +39,32 @@ const float kGnnEpsilon = 0.0001;
const uint32_t kMaxNumWalks = 80; const uint32_t kMaxNumWalks = 80;
using StochasticIndex = std::pair<std::vector<int32_t>, std::vector<float>>; using StochasticIndex = std::pair<std::vector<int32_t>, std::vector<float>>;
struct MetaInfo { class GraphDataImpl : public GraphData {
std::vector<NodeType> node_type;
std::vector<EdgeType> edge_type;
std::map<NodeType, NodeIdType> node_num;
std::map<EdgeType, EdgeIdType> edge_num;
std::vector<FeatureType> node_feature_type;
std::vector<FeatureType> edge_feature_type;
};
class Graph {
public: public:
// Constructor // Constructor
// @param std::string dataset_file - // @param std::string dataset_file -
// @param int32_t num_workers - number of parallel threads // @param int32_t num_workers - number of parallel threads
Graph(std::string dataset_file, int32_t num_workers); GraphDataImpl(std::string dataset_file, int32_t num_workers, bool server_mode = false);
~Graph() = default; ~GraphDataImpl();
// Get all nodes from the graph. // Get all nodes from the graph.
// @param NodeType node_type - type of node // @param NodeType node_type - type of node
// @param std::shared_ptr<Tensor> *out - Returned nodes id // @param std::shared_ptr<Tensor> *out - Returned nodes id
// @return Status - The error code return // @return Status - The error code return
Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out); Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) override;
// Get all edges from the graph. // Get all edges from the graph.
// @param NodeType edge_type - type of edge // @param NodeType edge_type - type of edge
// @param std::shared_ptr<Tensor> *out - Returned edge ids // @param std::shared_ptr<Tensor> *out - Returned edge ids
// @return Status - The error code return // @return Status - The error code return
Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out); Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) override;
// Get the node id from the edge. // Get the node id from the edge.
// @param std::vector<EdgeIdType> edge_list - List of edges // @param std::vector<EdgeIdType> edge_list - List of edges
// @param std::shared_ptr<Tensor> *out - Returned node ids // @param std::shared_ptr<Tensor> *out - Returned node ids
// @return Status - The error code return // @return Status - The error code return
Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out); Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override;
// All neighbors of the acquisition node. // All neighbors of the acquisition node.
// @param std::vector<NodeType> node_list - List of nodes // @param std::vector<NodeType> node_list - List of nodes
@ -85,7 +74,7 @@ class Graph {
// is not enough, fill in tensor as -1. // is not enough, fill in tensor as -1.
// @return Status - The error code return // @return Status - The error code return
Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
std::shared_ptr<Tensor> *out); std::shared_ptr<Tensor> *out) override;
// Get sampled neighbors. // Get sampled neighbors.
// @param std::vector<NodeType> node_list - List of nodes // @param std::vector<NodeType> node_list - List of nodes
@ -94,7 +83,7 @@ class Graph {
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id. // @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
// @return Status - The error code return // @return Status - The error code return
Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums, Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out); const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override;
// Get negative sampled neighbors. // Get negative sampled neighbors.
// @param std::vector<NodeType> node_list - List of nodes // @param std::vector<NodeType> node_list - List of nodes
@ -103,7 +92,7 @@ class Graph {
// @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id. // @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id.
// @return Status - The error code return // @return Status - The error code return
Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num, Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out); NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) override;
// Node2vec random walk. // Node2vec random walk.
// @param std::vector<NodeIdType> node_list - List of nodes // @param std::vector<NodeIdType> node_list - List of nodes
@ -115,7 +104,7 @@ class Graph {
// @return Status - The error code return // @return Status - The error code return
Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
float step_home_param, float step_away_param, NodeIdType default_node, float step_home_param, float step_away_param, NodeIdType default_node,
std::shared_ptr<Tensor> *out); std::shared_ptr<Tensor> *out) override;
// Get the feature of a node // Get the feature of a node
// @param std::shared_ptr<Tensor> nodes - List of nodes // @param std::shared_ptr<Tensor> nodes - List of nodes
@ -124,16 +113,22 @@ class Graph {
// @param TensorRow *out - Returned features // @param TensorRow *out - Returned features
// @return Status - The error code return // @return Status - The error code return
Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types, Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
TensorRow *out); TensorRow *out) override;
Status GetNodeFeatureSharedMemory(const std::shared_ptr<Tensor> &nodes, FeatureType type,
std::shared_ptr<Tensor> *out);
// Get the feature of a edge // Get the feature of a edge
// @param std::shared_ptr<Tensor> edget - List of edges // @param std::shared_ptr<Tensor> edges - List of edges
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type // @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
// does not exist. // does not exist.
// @param Tensor *out - Returned features // @param Tensor *out - Returned features
// @return Status - The error code return // @return Status - The error code return
Status GetEdgeFeature(const std::shared_ptr<Tensor> &edget, const std::vector<FeatureType> &feature_types, Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types,
TensorRow *out); TensorRow *out) override;
Status GetEdgeFeatureSharedMemory(const std::shared_ptr<Tensor> &edges, FeatureType type,
std::shared_ptr<Tensor> *out);
// Get meta information of graph // Get meta information of graph
// @param MetaInfo *meta_info - Returned meta information // @param MetaInfo *meta_info - Returned meta information
@ -142,15 +137,34 @@ class Graph {
#ifdef ENABLE_PYTHON #ifdef ENABLE_PYTHON
// Return meta information to python layer // Return meta information to python layer
Status GraphInfo(py::dict *out); Status GraphInfo(py::dict *out) override;
#endif #endif
Status Init(); const std::unordered_map<FeatureType, std::shared_ptr<Feature>> *GetAllDefaultNodeFeatures() {
return &default_node_feature_map_;
}
const std::unordered_map<FeatureType, std::shared_ptr<Feature>> *GetAllDefaultEdgeFeatures() {
return &default_edge_feature_map_;
}
Status Init() override;
Status Stop() override { return Status::OK(); }
std::string GetDataSchema() { return data_schema_.dump(); }
#if !defined(_WIN32) && !defined(_WIN64)
key_t GetSharedMemoryKey() { return graph_shared_memory_->memory_key(); }
int64_t GetSharedMemorySize() { return graph_shared_memory_->memory_size(); }
#endif
private: private:
friend class GraphLoader;
class RandomWalkBase { class RandomWalkBase {
public: public:
explicit RandomWalkBase(Graph *graph); explicit RandomWalkBase(GraphDataImpl *graph);
Status Build(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path, Status Build(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
float step_home_param = 1.0, float step_away_param = 1.0, NodeIdType default_node = -1, float step_home_param = 1.0, float step_away_param = 1.0, NodeIdType default_node = -1,
@ -176,7 +190,7 @@ class Graph {
template <typename T> template <typename T>
std::vector<float> Normalize(const std::vector<T> &non_normalized_probability); std::vector<float> Normalize(const std::vector<T> &non_normalized_probability);
Graph *graph_; GraphDataImpl *graph_;
std::vector<NodeIdType> node_list_; std::vector<NodeIdType> node_list_;
std::vector<NodeType> meta_path_; std::vector<NodeType> meta_path_;
float step_home_param_; // Return hyper parameter. Default is 1.0 float step_home_param_; // Return hyper parameter. Default is 1.0
@ -248,7 +262,11 @@ class Graph {
int32_t num_workers_; // The number of worker threads int32_t num_workers_; // The number of worker threads
std::mt19937 rnd_; std::mt19937 rnd_;
RandomWalkBase random_walk_; RandomWalkBase random_walk_;
mindrecord::json data_schema_;
bool server_mode_;
#if !defined(_WIN32) && !defined(_WIN64)
std::unique_ptr<GraphSharedMemory> graph_shared_memory_;
#endif
std::unordered_map<NodeType, std::vector<NodeIdType>> node_type_map_; std::unordered_map<NodeType, std::vector<NodeIdType>> node_type_map_;
std::unordered_map<NodeIdType, std::shared_ptr<Node>> node_id_map_; std::unordered_map<NodeIdType, std::shared_ptr<Node>> node_id_map_;
@ -264,4 +282,4 @@ class Graph {
} // namespace gnn } // namespace gnn
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_H_ #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_IMPL_H_

View File

@ -0,0 +1,133 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/gnn/graph_data_server.h"
#include <algorithm>
#include <functional>
#include <iterator>
#include <numeric>
#include <utility>
#include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/engine/gnn/graph_data_impl.h"
#include "minddata/dataset/util/random.h"
namespace mindspore {
namespace dataset {
namespace gnn {
// Creates a graph data server for the graph stored in `dataset_file`.
// Nothing is loaded or started here; the constructor only wires the pieces together:
//   - a TaskGroup that will own the asynchronous tasks started by Init(),
//   - the GraphDataImpl, constructed with its third argument (server_mode) set to true,
//   - on non-Windows builds, the gRPC service implementation and the async server bound to
//     `hostname`:`port`.
// @param dataset_file - path of the graph dataset file
// @param num_workers - number of worker threads
// @param hostname - host name the gRPC server is created with
// @param port - port the gRPC server is created with
// @param client_num - number of clients compared against the connected-client peak by the auto-shutdown check
// @param auto_shutdown - whether Init() starts the auto-shutdown task
GraphDataServer::GraphDataServer(const std::string &dataset_file, int32_t num_workers, const std::string &hostname,
                                 int32_t port, int32_t client_num, bool auto_shutdown)
    : dataset_file_(dataset_file),
      num_workers_(num_workers),
      client_num_(client_num),
      max_connected_client_num_(0),
      auto_shutdown_(auto_shutdown),
      state_(kGdsUninit) {
  tg_ = std::make_unique<TaskGroup>();
  graph_data_impl_ = std::make_unique<GraphDataImpl>(dataset_file_, num_workers_, true);
#if !defined(_WIN32) && !defined(_WIN64)
  // The service keeps non-owning pointers to this server and to the graph implementation.
  GraphDataImpl *const impl = graph_data_impl_.get();
  service_impl_ = std::make_unique<GraphDataServiceImpl>(this, impl);
  async_server_ = std::make_unique<GraphDataGrpcServer>(hostname, port, service_impl_.get());
#endif
}
// Starts the server. On Windows this always fails with an "unexpected" status.
// Otherwise it moves the state to kGdsInitializing, runs the gRPC server, and then
// schedules the following asynchronous tasks on the task group:
//   - one task that loads the graph (InitGraphDataImpl),
//   - num_workers_ tasks that serve RPC requests (StartAsyncRpcService),
//   - when auto_shutdown_ is set, one task that watches for all clients leaving.
// @return Status - The error code return
Status GraphDataServer::Init() {
#if defined(_WIN32) || defined(_WIN64)
  RETURN_STATUS_UNEXPECTED("Graph data server is not supported in Windows OS");
#else
  set_state(kGdsInitializing);
  RETURN_IF_NOT_OK(async_server_->Run());

  // Graph loading is done in the background so that the RPC service is reachable meanwhile.
  auto load_graph = [this]() -> Status { return InitGraphDataImpl(); };
  RETURN_IF_NOT_OK(tg_->CreateAsyncTask("init graph data impl", load_graph));

  auto serve_rpc = [this]() -> Status { return StartAsyncRpcService(); };
  for (int32_t worker = 0; worker < num_workers_; ++worker) {
    RETURN_IF_NOT_OK(tg_->CreateAsyncTask("start async rpc service", serve_rpc));
  }

  if (auto_shutdown_) {
    auto watch_clients = [this]() -> Status { return JudgeAutoShutdownServer(); };
    RETURN_IF_NOT_OK(tg_->CreateAsyncTask("judge auto shutdown server", watch_clients));
  }
  return Status::OK();
#endif
}
// Task body that loads the graph data.
// On success the server state becomes kGdsRunning; on failure the server is stopped
// (the result of Stop() is deliberately ignored) and the load error is returned.
// @return Status - the status returned by GraphDataImpl::Init()
Status GraphDataServer::InitGraphDataImpl() {
  TaskManager::FindMe()->Post();
  Status load_rc = graph_data_impl_->Init();
  if (!load_rc.IsOk()) {
    (void)Stop();
    return load_rc;
  }
  set_state(kGdsRunning);
  return load_rc;
}
#if !defined(_WIN32) && !defined(_WIN64)
// Task body of one RPC worker: hands control to the async gRPC server's request handler
// and returns whatever status it finishes with.
Status GraphDataServer::StartAsyncRpcService() {
  TaskManager::FindMe()->Post();
  return async_server_->HandleRequest();
}
#endif
// Task body that polls once per second and stops the server when auto shutdown is enabled,
// the peak number of registered clients has reached client_num_, and no client is
// registered any more. The loop also ends when the server has been stopped elsewhere.
// @return Status - The error code return (the result of Stop() when it fails)
Status GraphDataServer::JudgeAutoShutdownServer() {
  TaskManager::FindMe()->Post();
  while (true) {
    bool all_clients_left = false;
    {
      // client_pid_ and max_connected_client_num_ are modified by ClientRegister /
      // ClientUnRegister under mutex_, so they must be read under the same lock.
      // The lock is released before Stop() is called.
      std::unique_lock<std::mutex> lck(mutex_);
      all_clients_left = (max_connected_client_num_ >= client_num_) && client_pid_.empty();
    }
    if (auto_shutdown_ && all_clients_left) {
      MS_LOG(INFO) << "All clients have been unregister, automatically exit the server.";
      RETURN_IF_NOT_OK(Stop());
      break;
    }
    if (state_ == kGdsStopped) {
      break;
    }
    std::this_thread::sleep_for(std::chrono::milliseconds(1000));
  }
  return Status::OK();
}
// Stops the server: shuts down the async gRPC server (non-Windows builds only), marks the
// state as kGdsStopped and releases the graph data.
// NOTE(review): service_impl_ was constructed with the raw pointer graph_data_impl_.get();
// after the reset() below that pointer dangles. This presumably relies on async_server_->Stop()
// having quiesced all request handlers first - confirm against GrpcAsyncServer.
// NOTE(review): this function is reachable from InitGraphDataImpl, JudgeAutoShutdownServer
// and external callers, so it may run more than once - confirm async_server_->Stop() tolerates that.
// @return Status - always Status::OK()
Status GraphDataServer::Stop() {
#if !defined(_WIN32) && !defined(_WIN64)
async_server_->Stop();
#endif
set_state(kGdsStopped);
graph_data_impl_.reset();
return Status::OK();
}
// Records `pid` as a connected client and keeps track of the largest number of clients
// that have been registered at the same time. Guarded by mutex_.
// @param int32_t pid - process id reported by the client
// @return Status - always Status::OK()
Status GraphDataServer::ClientRegister(int32_t pid) {
  std::lock_guard<std::mutex> guard(mutex_);
  MS_LOG(INFO) << "client register pid:" << std::to_string(pid);
  client_pid_.emplace(pid);
  const auto connected_now = client_pid_.size();
  if (connected_now > max_connected_client_num_) {
    max_connected_client_num_ = connected_now;
  }
  return Status::OK();
}
// Removes `pid` from the set of connected clients, if it is present. Unknown pids are
// ignored silently. Guarded by mutex_.
// @param int32_t pid - process id reported by the client
// @return Status - always Status::OK()
Status GraphDataServer::ClientUnRegister(int32_t pid) {
  std::lock_guard<std::mutex> guard(mutex_);
  const bool was_registered = client_pid_.erase(pid) > 0;
  if (was_registered) {
    MS_LOG(INFO) << "client unregister pid:" << std::to_string(pid);
  }
  return Status::OK();
}
} // namespace gnn
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,196 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVER_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVER_H_
#include <atomic>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_set>
#if !defined(_WIN32) && !defined(_WIN64)
#include "grpcpp/grpcpp.h"
#include "minddata/dataset/engine/gnn/graph_data_service_impl.h"
#include "minddata/dataset/engine/gnn/grpc_async_server.h"
#endif
#include "minddata/dataset/util/task_manager.h"
namespace mindspore {
namespace dataset {
namespace gnn {
class GraphDataImpl;
// Server side of the distributed graph data service. It owns the in-memory graph
// (GraphDataImpl), the task group running the background tasks, the set of registered
// client pids and, on non-Windows builds, the gRPC service and async server.
class GraphDataServer {
 public:
  // Life cycle of the server: kGdsUninit -> kGdsInitializing -> kGdsRunning -> kGdsStopped.
  enum ServerState { kGdsUninit = 0, kGdsInitializing, kGdsRunning, kGdsStopped };

  // Constructor
  // @param std::string dataset_file - path of the graph dataset file
  // @param int32_t num_workers - number of worker threads
  // @param std::string hostname - host name the gRPC server is created with
  // @param int32_t port - port the gRPC server is created with
  // @param int32_t client_num - number of clients used by the auto-shutdown check
  // @param bool auto_shutdown - whether to stop automatically once all clients have left
  GraphDataServer(const std::string &dataset_file, int32_t num_workers, const std::string &hostname, int32_t port,
                  int32_t client_num, bool auto_shutdown);
  ~GraphDataServer() = default;

  // Start the server and its background tasks.
  // @return Status - The error code return
  Status Init();

  // Stop the server and release the graph data.
  // @return Status - The error code return
  Status Stop();

  // Add / remove a client pid from the set of connected clients.
  // @return Status - The error code return
  Status ClientRegister(int32_t pid);
  Status ClientUnRegister(int32_t pid);

  // Current server state.
  enum ServerState state() { return state_; }

  // Whether the server has reached kGdsStopped.
  bool IsStoped() { return state_ == kGdsStopped; }

 private:
  void set_state(enum ServerState state) { state_ = state; }

  Status InitGraphDataImpl();
#if !defined(_WIN32) && !defined(_WIN64)
  Status StartAsyncRpcService();
#endif
  Status JudgeAutoShutdownServer();

  std::string dataset_file_;
  int32_t num_workers_;  // The number of worker threads
  int32_t client_num_;
  int32_t max_connected_client_num_;  // Peak size of client_pid_; updated under mutex_
  bool auto_shutdown_;
  // The state is written by the graph-loading task and by Stop(), and read by the
  // auto-shutdown task and by the gRPC service, so it is atomic rather than a plain enum.
  std::atomic<ServerState> state_;
  std::unique_ptr<TaskGroup> tg_;  // Class for worker management
  std::unique_ptr<GraphDataImpl> graph_data_impl_;
  std::unordered_set<int32_t> client_pid_;  // Pids of registered clients; guarded by mutex_
  std::mutex mutex_;
#if !defined(_WIN32) && !defined(_WIN64)
  std::unique_ptr<GraphDataServiceImpl> service_impl_;
  std::unique_ptr<GrpcAsyncServer> async_server_;
#endif
};
#if !defined(_WIN32) && !defined(_WIN64)
// Type-erased handle for one in-flight RPC. The completion-queue tag is cast back to this
// interface (see GraphDataGrpcServer::ProcessRequest) and invoked to advance the call.
class UntypedCall {
 public:
  virtual ~UntypedCall() = default;

  // Advance the call by one step.
  // @return Status - The error code return
  virtual Status operator()() = 0;
};
// State machine for a single asynchronous unary RPC.
//   CREATE  - the object asks the async service to deliver the next request of its kind,
//             using itself as the completion-queue tag;
//   PROCESS - a request has arrived: a fresh CallData is enqueued for the next request, the
//             handler is run and the response is sent;
//   FINISH  - the response has been sent and the object deletes itself.
// Instances are created on the heap through EnqueueRequest() and own themselves.
// @tparam ServiceImpl - class implementing the synchronous handler
// @tparam AsyncService - generated gRPC async service class
// @tparam RequestMessage / ResponseMessage - protobuf message types of the RPC
template <class ServiceImpl, class AsyncService, class RequestMessage, class ResponseMessage>
class CallData : public UntypedCall {
 public:
  enum class STATE : int8_t { CREATE = 1, PROCESS = 2, FINISH = 3 };

  // Pointer to the generated AsyncService::Request<Method> member function.
  using EnqueueFunction = void (AsyncService::*)(grpc::ServerContext *, RequestMessage *,
                                                 grpc::ServerAsyncResponseWriter<ResponseMessage> *,
                                                 grpc::CompletionQueue *, grpc::ServerCompletionQueue *, void *);
  // Pointer to the ServiceImpl member function that actually serves the request.
  using HandleRequestFunction = grpc::Status (ServiceImpl::*)(grpc::ServerContext *, const RequestMessage *,
                                                              ResponseMessage *);

  CallData(ServiceImpl *service_impl, AsyncService *async_service, grpc::ServerCompletionQueue *cq,
           EnqueueFunction enqueue_function, HandleRequestFunction handle_request_function)
      : status_(STATE::CREATE),
        service_impl_(service_impl),
        async_service_(async_service),
        cq_(cq),
        enqueue_function_(enqueue_function),
        handle_request_function_(handle_request_function),
        responder_(&ctx_) {}

  ~CallData() = default;

  // Allocate a new CallData and register it with the async service. The object is released
  // by its own FINISH step, not by the caller.
  // @return Status - The error code return
  static Status EnqueueRequest(ServiceImpl *service_impl, AsyncService *async_service, grpc::ServerCompletionQueue *cq,
                               EnqueueFunction enqueue_function, HandleRequestFunction handle_request_function) {
    auto call = new CallData<ServiceImpl, AsyncService, RequestMessage, ResponseMessage>(
      service_impl, async_service, cq, enqueue_function, handle_request_function);
    RETURN_IF_NOT_OK((*call)());
    return Status::OK();
  }

  // Advance the state machine by one step.
  // @return Status - OK, or the error from enqueuing the follow-up request in the PROCESS step
  Status operator()() override {
    Status rc = Status::OK();
    if (status_ == STATE::CREATE) {
      status_ = STATE::PROCESS;
      (async_service_->*enqueue_function_)(&ctx_, &request_, &responder_, cq_, cq_, this);
    } else if (status_ == STATE::PROCESS) {
      // Enqueue the follow-up request before serving this one. A failure is reported to the
      // caller, but only after the current request has been answered, so that this call
      // still completes.
      rc = EnqueueRequest(service_impl_, async_service_, cq_, enqueue_function_, handle_request_function_);
      status_ = STATE::FINISH;
      grpc::Status s = (service_impl_->*handle_request_function_)(&ctx_, &request_, &response_);
      responder_.Finish(response_, s, this);
    } else {
      GPR_ASSERT(status_ == STATE::FINISH);
      delete this;
    }
    return rc;
  }

 private:
  STATE status_;
  ServiceImpl *service_impl_;
  AsyncService *async_service_;
  grpc::ServerCompletionQueue *cq_;
  EnqueueFunction enqueue_function_;
  HandleRequestFunction handle_request_function_;
  grpc::ServerContext ctx_;
  grpc::ServerAsyncResponseWriter<ResponseMessage> responder_;
  RequestMessage request_;
  ResponseMessage response_;
};
// Enqueues one pending call for the RPC `method` of GnnGraphData.
// Expands to a CallData<...>::EnqueueRequest() invocation wiring the generated
// AsyncService::Request<method> to GraphDataServiceImpl::<method>, and returns from the
// enclosing function if that fails. Must be used inside a function returning Status.
#define ENQUEUE_REQUEST(service_impl, async_service, cq, method, request_msg, response_msg)                         \
  do {                                                                                                              \
    Status enqueue_status_ =                                                                                        \
      CallData<gnn::GraphDataServiceImpl, GnnGraphData::AsyncService, request_msg, response_msg>::EnqueueRequest( \
        service_impl, async_service, cq, &GnnGraphData::AsyncService::Request##method,                              \
        &gnn::GraphDataServiceImpl::method);                                                                        \
    RETURN_IF_NOT_OK(enqueue_status_);                                                                              \
  } while (0)
// GrpcAsyncServer specialisation that serves the GnnGraphData service.
// NOTE(review): RegisterService / EnqueueRequest / ProcessRequest presumably override hooks
// declared by GrpcAsyncServer; the base class is not visible here - confirm before adding `override`.
class GraphDataGrpcServer : public GrpcAsyncServer {
 public:
  // @param host, port - passed to the GrpcAsyncServer base class
  // @param service_impl - non-owning pointer to the object that handles the requests
  GraphDataGrpcServer(const std::string &host, int32_t port, GraphDataServiceImpl *service_impl)
      : GrpcAsyncServer(host, port), service_impl_(service_impl) {}

  // Register the async GnnGraphData service with the server builder.
  Status RegisterService(grpc::ServerBuilder *builder) {
    builder->RegisterService(&svc_);
    return Status::OK();
  }

  // Enqueue one pending call for each RPC of the service.
  Status EnqueueRequest() {
    auto *queue = cq_.get();
    ENQUEUE_REQUEST(service_impl_, &svc_, queue, ClientRegister, GnnClientRegisterRequestPb,
                    GnnClientRegisterResponsePb);
    ENQUEUE_REQUEST(service_impl_, &svc_, queue, ClientUnRegister, GnnClientUnRegisterRequestPb,
                    GnnClientUnRegisterResponsePb);
    ENQUEUE_REQUEST(service_impl_, &svc_, queue, GetGraphData, GnnGraphDataRequestPb, GnnGraphDataResponsePb);
    ENQUEUE_REQUEST(service_impl_, &svc_, queue, GetMetaInfo, GnnMetaInfoRequestPb, GnnMetaInfoResponsePb);
    return Status::OK();
  }

  // Advance the call identified by a completion-queue tag; the tag is the UntypedCall
  // that was passed when the request was enqueued.
  Status ProcessRequest(void *tag) {
    UntypedCall *call = static_cast<UntypedCall *>(tag);
    return (*call)();
  }

 private:
  GraphDataServiceImpl *service_impl_;
  GnnGraphData::AsyncService svc_;
};
#endif
} // namespace gnn
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVER_H_

View File

@ -0,0 +1,299 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/gnn/graph_data_service_impl.h"
#include <algorithm>
#include <unordered_map>
#include <vector>
#include "minddata/dataset/engine/gnn/tensor_proto.h"
#include "minddata/dataset/engine/gnn/graph_data_server.h"
namespace mindspore {
namespace dataset {
namespace gnn {
// Signature shared by all GraphDataServiceImpl handlers that serve a graph data request.
using pFunction = Status (GraphDataServiceImpl::*)(const GnnGraphDataRequestPb *, GnnGraphDataResponsePb *);
// Dispatch table from a graph data operation id to the GraphDataServiceImpl member
// function that handles it.
// NOTE(review): the table is never modified in the code visible here and could probably be
// const; its users are outside this section, so confirm none of them rely on operator[].
static std::unordered_map<uint32_t, pFunction> g_get_graph_data_func_ = {
{GET_ALL_NODES, &GraphDataServiceImpl::GetAllNodes},
{GET_ALL_EDGES, &GraphDataServiceImpl::GetAllEdges},
{GET_NODES_FROM_EDGES, &GraphDataServiceImpl::GetNodesFromEdges},
{GET_ALL_NEIGHBORS, &GraphDataServiceImpl::GetAllNeighbors},
{GET_SAMPLED_NEIGHBORS, &GraphDataServiceImpl::GetSampledNeighbors},
{GET_NEG_SAMPLED_NEIGHBORS, &GraphDataServiceImpl::GetNegSampledNeighbors},
{RANDOM_WALK, &GraphDataServiceImpl::RandomWalk},
{GET_NODE_FEATURE, &GraphDataServiceImpl::GetNodeFeature},
{GET_EDGE_FEATURE, &GraphDataServiceImpl::GetEdgeFeature}};
// Stores non-owning pointers to the server (used for client registration and state queries)
// and to the graph data implementation (used to answer data requests).
GraphDataServiceImpl::GraphDataServiceImpl(GraphDataServer *server, GraphDataImpl *graph_data_impl)
: server_(server), graph_data_impl_(graph_data_impl) {}
// Copies every default node feature and default edge feature of the graph into the
// register response, one GnnFeatureInfoPb (type + serialized tensor) per feature.
// @param GnnClientRegisterResponsePb *response - response message to be filled
// @return Status - The error code return (the first TensorToPb failure, if any)
Status GraphDataServiceImpl::FillDefaultFeature(GnnClientRegisterResponsePb *response) {
  const auto default_node_features = graph_data_impl_->GetAllDefaultNodeFeatures();
  // Bind each map entry by reference: iterating by value copies the pair, including a
  // std::shared_ptr<Feature>, on every iteration.
  for (const auto &feature : *default_node_features) {
    GnnFeatureInfoPb *feature_info = response->add_default_node_feature();
    feature_info->set_type(feature.first);
    RETURN_IF_NOT_OK(TensorToPb(feature.second->Value(), feature_info->mutable_feature()));
  }
  const auto default_edge_features = graph_data_impl_->GetAllDefaultEdgeFeatures();
  for (const auto &feature : *default_edge_features) {
    GnnFeatureInfoPb *feature_info = response->add_default_edge_feature();
    feature_info->set_type(feature.first);
    RETURN_IF_NOT_OK(TensorToPb(feature.second->Value(), feature_info->mutable_feature()));
  }
  return Status::OK();
}
// Registers the calling client (identified by request->pid()) with the server and reports the outcome in
// response->error_msg(). When the server is running, the data schema, the shared memory key/size and the
// default features are also returned. The gRPC status itself is always OK.
grpc::Status GraphDataServiceImpl::ClientRegister(grpc::ServerContext *context,
                                                  const GnnClientRegisterRequestPb *request,
                                                  GnnClientRegisterResponsePb *response) {
  Status register_rc = server_->ClientRegister(request->pid());
  if (!register_rc.IsOk()) {
    // Registration itself failed: forward the reason and stop here.
    response->set_error_msg(register_rc.ToString());
    return ::grpc::Status::OK;
  }

  switch (server_->state()) {
    case GraphDataServer::kGdsUninit:
    case GraphDataServer::kGdsInitializing:
      response->set_error_msg("Initializing");
      break;
    case GraphDataServer::kGdsRunning: {
      response->set_error_msg("Success");
      response->set_data_schema(graph_data_impl_->GetDataSchema());
      response->set_shared_memory_key(graph_data_impl_->GetSharedMemoryKey());
      response->set_shared_memory_size(graph_data_impl_->GetSharedMemorySize());
      Status fill_rc = FillDefaultFeature(response);
      if (!fill_rc.IsOk()) {
        // Overwrite the optimistic "Success" set above.
        response->set_error_msg(fill_rc.ToString());
      }
      break;
    }
    case GraphDataServer::kGdsStopped:
      response->set_error_msg("Stoped");
      break;
  }
  return ::grpc::Status::OK;
}
// Unregisters the client identified by request->pid() and writes "Success" or the failure text into
// response->error_msg(). The gRPC status itself is always OK.
grpc::Status GraphDataServiceImpl::ClientUnRegister(grpc::ServerContext *context,
                                                    const GnnClientUnRegisterRequestPb *request,
                                                    GnnClientUnRegisterResponsePb *response) {
  Status unregister_rc = server_->ClientUnRegister(request->pid());
  response->set_error_msg(unregister_rc.IsOk() ? std::string("Success") : unregister_rc.ToString());
  return ::grpc::Status::OK;
}
// Dispatches a graph data request to the handler registered for request->op_name() in
// g_get_graph_data_func_. The handler's outcome ("Success" or the error text) is written to
// response->error_msg(); an unknown op yields "Invalid op name.". The gRPC status itself is always OK.
grpc::Status GraphDataServiceImpl::GetGraphData(grpc::ServerContext *context, const GnnGraphDataRequestPb *request,
                                                GnnGraphDataResponsePb *response) {
  const auto entry = g_get_graph_data_func_.find(request->op_name());
  if (entry == g_get_graph_data_func_.end()) {
    response->set_error_msg("Invalid op name.");
    return ::grpc::Status::OK;
  }

  pFunction handler = entry->second;
  Status handler_rc = (this->*handler)(request, response);
  response->set_error_msg(handler_rc.IsOk() ? std::string("Success") : handler_rc.ToString());
  return ::grpc::Status::OK;
}
// Fetches the graph meta information and copies it into the response: one entry per node type and per edge
// type (with its count, or 0 when no count is recorded), followed by the node and edge feature types.
// On failure only response->error_msg() is set. The gRPC status itself is always OK.
grpc::Status GraphDataServiceImpl::GetMetaInfo(grpc::ServerContext *context, const GnnMetaInfoRequestPb *request,
                                               GnnMetaInfoResponsePb *response) {
  MetaInfo meta;
  Status meta_rc = graph_data_impl_->GetMetaInfo(&meta);
  if (!meta_rc.IsOk()) {
    response->set_error_msg(meta_rc.ToString());
    return ::grpc::Status::OK;
  }

  response->set_error_msg("Success");

  for (const auto &node_type : meta.node_type) {
    auto *info = response->add_node_info();
    info->set_type(static_cast<google::protobuf::int32>(node_type));
    const auto count = meta.node_num.find(node_type);
    if (count == meta.node_num.end()) {
      info->set_num(0);
    } else {
      info->set_num(static_cast<google::protobuf::int32>(count->second));
    }
  }

  for (const auto &edge_type : meta.edge_type) {
    auto *info = response->add_edge_info();
    info->set_type(static_cast<google::protobuf::int32>(edge_type));
    const auto count = meta.edge_num.find(edge_type);
    if (count == meta.edge_num.end()) {
      info->set_num(0);
    } else {
      info->set_num(static_cast<google::protobuf::int32>(count->second));
    }
  }

  for (const auto &feature_type : meta.node_feature_type) {
    response->add_node_feature_type(static_cast<google::protobuf::int32>(feature_type));
  }
  for (const auto &feature_type : meta.edge_feature_type) {
    response->add_edge_feature_type(static_cast<google::protobuf::int32>(feature_type));
  }
  return ::grpc::Status::OK;
}
// Returns all nodes of a single node type.
// @param const GnnGraphDataRequestPb *request - must carry exactly one entry in `type`, the node type
// @param GnnGraphDataResponsePb *response - receives one result tensor
// @return Status - the status code
Status GraphDataServiceImpl::GetAllNodes(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) {
  // The single type is interpreted as a node type, so the message must say so.
  CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() == 1, "The number of node types is not 1");
  std::shared_ptr<Tensor> tensor;
  RETURN_IF_NOT_OK(graph_data_impl_->GetAllNodes(static_cast<NodeType>(request->type()[0]), &tensor));
  TensorPb *result = response->add_result_data();
  RETURN_IF_NOT_OK(TensorToPb(tensor, result));
  return Status::OK();
}
// Returns all edges of a single edge type.
// @param const GnnGraphDataRequestPb *request - must carry exactly one entry in `type`, the edge type
// @param GnnGraphDataResponsePb *response - receives one result tensor
// @return Status - the status code
Status GraphDataServiceImpl::GetAllEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) {
  CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() == 1, "The number of edge types is not 1");
  const EdgeType edge_type = static_cast<EdgeType>(request->type(0));
  std::shared_ptr<Tensor> edges;
  RETURN_IF_NOT_OK(graph_data_impl_->GetAllEdges(edge_type, &edges));
  RETURN_IF_NOT_OK(TensorToPb(edges, response->add_result_data()));
  return Status::OK();
}
// Looks up the nodes attached to the requested edges.
// @param const GnnGraphDataRequestPb *request - `id` holds the edge ids and must not be empty
// @param GnnGraphDataResponsePb *response - receives one result tensor
// @return Status - the status code
Status GraphDataServiceImpl::GetNodesFromEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) {
  CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input edge id is empty");

  std::vector<EdgeIdType> edge_ids;
  edge_ids.reserve(request->id().size());
  for (const google::protobuf::int32 raw_id : request->id()) {
    edge_ids.push_back(static_cast<EdgeIdType>(raw_id));
  }

  std::shared_ptr<Tensor> nodes;
  RETURN_IF_NOT_OK(graph_data_impl_->GetNodesFromEdges(edge_ids, &nodes));
  RETURN_IF_NOT_OK(TensorToPb(nodes, response->add_result_data()));
  return Status::OK();
}
// Returns all neighbors of one type for the requested nodes.
// @param const GnnGraphDataRequestPb *request - `id` holds the node ids (non-empty); `type` holds exactly one
//     entry, which is passed on as the neighbor type
// @param GnnGraphDataResponsePb *response - receives one result tensor
// @return Status - the status code
Status GraphDataServiceImpl::GetAllNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) {
  CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input node id is empty");
  // The single type is cast to NodeType and used as the neighbor type, not as an edge type.
  CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() == 1, "The number of neighbor types is not 1");
  std::vector<NodeIdType> node_list;
  node_list.resize(request->id().size());
  std::transform(request->id().begin(), request->id().end(), node_list.begin(),
                 [](const google::protobuf::int32 id) { return static_cast<NodeIdType>(id); });
  std::shared_ptr<Tensor> tensor;
  RETURN_IF_NOT_OK(graph_data_impl_->GetAllNeighbors(node_list, static_cast<NodeType>(request->type()[0]), &tensor));
  TensorPb *result = response->add_result_data();
  RETURN_IF_NOT_OK(TensorToPb(tensor, result));
  return Status::OK();
}
// Samples neighbors for the requested nodes.
// @param const GnnGraphDataRequestPb *request - `id` holds the node ids, `number` the neighbor counts and
//     `type` the neighbor types; none of the three may be empty
// @param GnnGraphDataResponsePb *response - receives one result tensor
// @return Status - the status code
Status GraphDataServiceImpl::GetSampledNeighbors(const GnnGraphDataRequestPb *request,
                                                 GnnGraphDataResponsePb *response) {
  CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input node id is empty");
  CHECK_FAIL_RETURN_UNEXPECTED(request->number_size() > 0, "The input neighbor number is empty");
  CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() > 0, "The input neighbor type is empty");

  std::vector<NodeIdType> node_ids;
  node_ids.reserve(request->id().size());
  for (const google::protobuf::int32 raw_id : request->id()) {
    node_ids.push_back(static_cast<NodeIdType>(raw_id));
  }

  std::vector<NodeIdType> counts;
  counts.reserve(request->number().size());
  for (const google::protobuf::int32 raw_num : request->number()) {
    counts.push_back(static_cast<NodeIdType>(raw_num));
  }

  std::vector<NodeType> types;
  types.reserve(request->type().size());
  for (const google::protobuf::int32 raw_type : request->type()) {
    types.push_back(static_cast<NodeType>(raw_type));
  }

  std::shared_ptr<Tensor> sampled;
  RETURN_IF_NOT_OK(graph_data_impl_->GetSampledNeighbors(node_ids, counts, types, &sampled));
  RETURN_IF_NOT_OK(TensorToPb(sampled, response->add_result_data()));
  return Status::OK();
}
// Negative-samples neighbors for the requested nodes.
// @param const GnnGraphDataRequestPb *request - `id` holds the node ids (non-empty); `number` and `type` must
//     each hold exactly one entry: the sample count and the neighbor type
// @param GnnGraphDataResponsePb *response - receives one result tensor
// @return Status - the status code
Status GraphDataServiceImpl::GetNegSampledNeighbors(const GnnGraphDataRequestPb *request,
                                                    GnnGraphDataResponsePb *response) {
  CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input node id is empty");
  CHECK_FAIL_RETURN_UNEXPECTED(request->number_size() == 1, "The number of neighbor number is not 1");
  CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() == 1, "The number of neighbor types is not 1");

  std::vector<NodeIdType> node_ids;
  node_ids.reserve(request->id().size());
  for (const google::protobuf::int32 raw_id : request->id()) {
    node_ids.push_back(static_cast<NodeIdType>(raw_id));
  }

  const NodeIdType sample_count = static_cast<NodeIdType>(request->number(0));
  const NodeType neighbor_type = static_cast<NodeType>(request->type(0));

  std::shared_ptr<Tensor> sampled;
  RETURN_IF_NOT_OK(graph_data_impl_->GetNegSampledNeighbors(node_ids, sample_count, neighbor_type, &sampled));
  RETURN_IF_NOT_OK(TensorToPb(sampled, response->add_result_data()));
  return Status::OK();
}
// Runs a random walk from the requested nodes along a meta path.
// @param const GnnGraphDataRequestPb *request - `id` holds the start node ids and `type` the meta path (both
//     non-empty); `random_walk` supplies p, q and default_id
// @param GnnGraphDataResponsePb *response - receives one result tensor
// @return Status - the status code
Status GraphDataServiceImpl::RandomWalk(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) {
  CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input node id is empty");
  CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() > 0, "The input meta path is empty");

  std::vector<NodeIdType> start_nodes;
  start_nodes.reserve(request->id().size());
  for (const google::protobuf::int32 raw_id : request->id()) {
    start_nodes.push_back(static_cast<NodeIdType>(raw_id));
  }

  std::vector<NodeType> path_types;
  path_types.reserve(request->type().size());
  for (const google::protobuf::int32 raw_type : request->type()) {
    path_types.push_back(static_cast<NodeType>(raw_type));
  }

  const auto &walk_params = request->random_walk();
  std::shared_ptr<Tensor> walks;
  RETURN_IF_NOT_OK(graph_data_impl_->RandomWalk(start_nodes, path_types, walk_params.p(), walk_params.q(),
                                                walk_params.default_id(), &walks));
  RETURN_IF_NOT_OK(TensorToPb(walks, response->add_result_data()));
  return Status::OK();
}
// Fetches node features through GetNodeFeatureSharedMemory.
// @param const GnnGraphDataRequestPb *request - `id_tensor` holds the node ids; `type` lists the feature types
// @param GnnGraphDataResponsePb *response - receives one result tensor per requested feature type, in order
// @return Status - the status code
Status GraphDataServiceImpl::GetNodeFeature(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) {
  std::shared_ptr<Tensor> node_ids;
  RETURN_IF_NOT_OK(PbToTensor(&request->id_tensor(), &node_ids));

  const int type_count = request->type_size();
  for (int idx = 0; idx < type_count; ++idx) {
    std::shared_ptr<Tensor> feature;
    RETURN_IF_NOT_OK(graph_data_impl_->GetNodeFeatureSharedMemory(node_ids, request->type(idx), &feature));
    RETURN_IF_NOT_OK(TensorToPb(feature, response->add_result_data()));
  }
  return Status::OK();
}
// Fetches edge features through GetEdgeFeatureSharedMemory.
// @param const GnnGraphDataRequestPb *request - `id_tensor` holds the edge ids; `type` lists the feature types
// @param GnnGraphDataResponsePb *response - receives one result tensor per requested feature type, in order
// @return Status - the status code
Status GraphDataServiceImpl::GetEdgeFeature(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) {
  std::shared_ptr<Tensor> edge_ids;
  RETURN_IF_NOT_OK(PbToTensor(&request->id_tensor(), &edge_ids));

  const int type_count = request->type_size();
  for (int idx = 0; idx < type_count; ++idx) {
    std::shared_ptr<Tensor> feature;
    RETURN_IF_NOT_OK(graph_data_impl_->GetEdgeFeatureSharedMemory(edge_ids, request->type(idx), &feature));
    RETURN_IF_NOT_OK(TensorToPb(feature, response->add_result_data()));
  }
  return Status::OK();
}
} // namespace gnn
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,70 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVICE_IMPL_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVICE_IMPL_H_
#include <memory>
#include <string>
#include "minddata/dataset/engine/gnn/graph_data_impl.h"
#include "proto/gnn_graph_data.grpc.pb.h"
#include "proto/gnn_graph_data.pb.h"
namespace mindspore {
namespace dataset {
namespace gnn {
class GraphDataServer;
// Request handlers of the GNN graph data service. Each handler reads a protobuf request and fills the
// matching protobuf response.
// NOTE(review): the class does not derive from GnnGraphData::Service (a declaration that did was disabled);
// presumably the server invokes these handlers through its own dispatch - confirm.
class GraphDataServiceImpl {
public:
// @param GraphDataServer *server - server used for client (un)registration and state queries; stored as-is
// @param GraphDataImpl *graph_data_impl - graph data that the handlers query; stored as-is
GraphDataServiceImpl(GraphDataServer *server, GraphDataImpl *graph_data_impl);
~GraphDataServiceImpl() = default;
// Registers the client given by request->pid(); the outcome is reported through response->error_msg().
grpc::Status ClientRegister(grpc::ServerContext *context, const GnnClientRegisterRequestPb *request,
GnnClientRegisterResponsePb *response);
// Unregisters the client given by request->pid(); the outcome is reported through response->error_msg().
grpc::Status ClientUnRegister(grpc::ServerContext *context, const GnnClientUnRegisterRequestPb *request,
GnnClientUnRegisterResponsePb *response);
// Dispatches on request->op_name() to one of the Status-returning handlers declared below.
grpc::Status GetGraphData(grpc::ServerContext *context, const GnnGraphDataRequestPb *request,
GnnGraphDataResponsePb *response);
// Reports node/edge types with their counts and the node/edge feature types.
grpc::Status GetMetaInfo(grpc::ServerContext *context, const GnnMetaInfoRequestPb *request,
GnnMetaInfoResponsePb *response);
// Per-operation handlers. Each validates the request fields it needs, queries the graph data and appends
// the resulting tensor(s) to response->result_data.
Status GetAllNodes(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
Status GetAllEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
Status GetNodesFromEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
Status GetAllNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
Status GetSampledNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
Status GetNegSampledNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
Status RandomWalk(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
Status GetNodeFeature(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
Status GetEdgeFeature(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
private:
// Copies the default node and edge features into a client-register response.
Status FillDefaultFeature(GnnClientRegisterResponsePb *response);
// Raw pointers set by the constructor and never deleted in this class.
GraphDataServer *server_;
GraphDataImpl *graph_data_impl_;
};
} // namespace gnn
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVICE_IMPL_H_

View File

@ -0,0 +1,106 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/gnn/graph_feature_parser.h"
#include <memory>
#include <utility>
#include "mindspore/ccsrc/minddata/mindrecord/include/shard_error.h"
namespace mindspore {
namespace dataset {
namespace gnn {
using mindrecord::MSRStatus;
// Keeps a private copy of the given shard column description for later column lookups.
GraphFeatureParser::GraphFeatureParser(const ShardColumn &shard_column)
    : shard_column_(std::make_unique<ShardColumn>(shard_column)) {}
// Reads column `key` from a mindrecord blob and wraps its bytes in a one-dimensional tensor.
// @param const std::string &key - column name
// @param const std::vector<uint8_t> &col_blob - blob field of the mindrecord row
// @param std::shared_ptr<Tensor> *tensor - return value; it has n_bytes / element-size elements and the
//     column's data type
// @return Status - the status code
Status GraphFeatureParser::LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &col_blob,
                                             std::shared_ptr<Tensor> *tensor) {
  const unsigned char *data = nullptr;
  std::unique_ptr<unsigned char[]> data_ptr;
  uint64_t n_bytes = 0, col_type_size = 1;
  mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
  std::vector<int64_t> column_shape;
  MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type,
                                                     &col_type_size, &column_shape);
  CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key);
  // Guards the division below.
  CHECK_FAIL_RETURN_UNEXPECTED(col_type_size > 0, "invalid element size for column:" + key);
  // The column is returned either through `data` or through the owning buffer `data_ptr`;
  // get() is used so that an empty buffer is never indexed.
  if (data == nullptr) data = data_ptr.get();
  RETURN_IF_NOT_OK(Tensor::CreateFromMemory(TensorShape({static_cast<dsize_t>(n_bytes / col_type_size)}),
                                            DataType(mindrecord::ColumnDataTypeNameNormalized[col_type]), data,
                                            tensor));
  return Status::OK();
}
#if !defined(_WIN32) && !defined(_WIN64)
// Reads column `key` from a mindrecord blob, stores its bytes in shared memory and returns a descriptor
// tensor instead of the data itself.
// @param const std::string &key - column name
// @param const std::vector<uint8_t> &col_blob - blob field of the mindrecord row
// @param GraphSharedMemory *shared_memory - shared memory that receives the column bytes
// @param std::shared_ptr<Tensor> *out_tensor - return value; int64 tensor of shape {2} holding
//     {offset reported by InsertData, number of bytes}
// @return Status - the status code
Status GraphFeatureParser::LoadFeatureToSharedMemory(const std::string &key, const std::vector<uint8_t> &col_blob,
                                                     GraphSharedMemory *shared_memory,
                                                     std::shared_ptr<Tensor> *out_tensor) {
  const unsigned char *data = nullptr;
  std::unique_ptr<unsigned char[]> data_ptr;
  uint64_t n_bytes = 0, col_type_size = 1;
  mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
  std::vector<int64_t> column_shape;
  MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type,
                                                     &col_type_size, &column_shape);
  CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key);
  // The column is returned either through `data` or through the owning buffer `data_ptr`;
  // get() is used so that an empty buffer is never indexed.
  if (data == nullptr) data = data_ptr.get();
  std::shared_ptr<Tensor> tensor;
  RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape({2}), DataType(DataType::DE_INT64), &tensor));
  auto fea_itr = tensor->begin<int64_t>();
  int64_t offset = 0;
  RETURN_IF_NOT_OK(shared_memory->InsertData(data, n_bytes, &offset));
  *fea_itr = offset;
  ++fea_itr;
  // The tensor stores int64 values, so make the unsigned-to-signed conversion explicit.
  *fea_itr = static_cast<int64_t>(n_bytes);
  *out_tensor = std::move(tensor);
  return Status::OK();
}
#endif
// Reads column `key` from a mindrecord blob as a list of feature indices.
// @param const std::string &key - column name
// @param const std::vector<uint8_t> &col_blob - blob field of the mindrecord row
// @param std::vector<int32_t> *indices - return value; every non-negative index found is appended
// @return Status - error if the column cannot be loaded or its elements are not int32/int64
Status GraphFeatureParser::LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &col_blob,
                                            std::vector<int32_t> *indices) {
  const unsigned char *data = nullptr;
  std::unique_ptr<unsigned char[]> data_ptr;
  uint64_t n_bytes = 0, col_type_size = 1;
  mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
  std::vector<int64_t> column_shape;
  MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type,
                                                     &col_type_size, &column_shape);
  CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key);
  // A zero element size would keep the loop below from ever advancing.
  CHECK_FAIL_RETURN_UNEXPECTED(n_bytes == 0 || col_type_size > 0, "invalid element size for column:" + key);
  // The column is returned either through `data` or through the owning buffer `data_ptr`;
  // get() is used so that an empty buffer is never indexed.
  if (data == nullptr) data = data_ptr.get();
  // The offset has the same unsigned 64-bit type as n_bytes, and only complete elements are read.
  for (uint64_t offset = 0; offset + col_type_size <= n_bytes; offset += col_type_size) {
    int32_t feature_ind = -1;
    if (col_type == mindrecord::ColumnInt32) {
      feature_ind = *(reinterpret_cast<const int32_t *>(data + offset));
    } else if (col_type == mindrecord::ColumnInt64) {
      // Indices are kept as int32; the narrowing is intentional and now explicit.
      feature_ind = static_cast<int32_t>(*(reinterpret_cast<const int64_t *>(data + offset)));
    } else {
      RETURN_STATUS_UNEXPECTED("Feature Index needs to be int32/int64 type!");
    }
    // Negative values are skipped rather than reported.
    if (feature_ind >= 0) indices->push_back(feature_ind);
  }
  return Status::OK();
}
} // namespace gnn
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,67 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_FEATURE_PARSER_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_FEATURE_PARSER_H_
#include <memory>
#include <queue>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/core/tensor.h"
#if !defined(_WIN32) && !defined(_WIN64)
#include "minddata/dataset/engine/gnn/graph_shared_memory.h"
#endif
#include "minddata/dataset/engine/gnn/feature.h"
#include "minddata/dataset/util/status.h"
#include "minddata/mindrecord/include/shard_column.h"
namespace mindspore {
namespace dataset {
namespace gnn {
using mindrecord::ShardColumn;
// Extracts graph feature data (feature indices and feature tensors) from the blob field of a mindrecord
// row, using a ShardColumn to locate each column by name.
class GraphFeatureParser {
public:
// @param const ShardColumn &shard_column - column description; the parser keeps its own copy
explicit GraphFeatureParser(const ShardColumn &shard_column);
~GraphFeatureParser() = default;
// Reads a column of int32/int64 feature indices; non-negative values are appended to ind.
// @param std::string key - column name
// @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
// @param std::vector<int32_t> *ind - return value, list of feature index in int32_t
// @return Status - the status code
Status LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &blob, std::vector<int32_t> *ind);
// Reads a feature column into a one-dimensional tensor.
// @param std::string &key - column name
// @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
// @param std::shared_ptr<Tensor> *tensor - return value feature tensor
// @return Status - the status code
Status LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &blob, std::shared_ptr<Tensor> *tensor);
#if !defined(_WIN32) && !defined(_WIN64)
// Reads a feature column, inserts its bytes into shared memory and returns a descriptor tensor.
// Only declared on non-Windows builds.
// @param std::string &key - column name
// @param std::vector<uint8_t> &col_blob - contains data in blob field in mindrecord
// @param GraphSharedMemory *shared_memory - shared memory that receives the feature bytes
// @param std::shared_ptr<Tensor> *out_tensor - return value, int64 tensor of shape {2}: {offset, byte count}
// @return Status - the status code
Status LoadFeatureToSharedMemory(const std::string &key, const std::vector<uint8_t> &col_blob,
GraphSharedMemory *shared_memory, std::shared_ptr<Tensor> *out_tensor);
#endif
private:
// Owned copy of the column description passed to the constructor.
std::unique_ptr<ShardColumn> shard_column_;
};
} // namespace gnn
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_FEATURE_PARSER_H_

View File

@ -13,41 +13,42 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "minddata/dataset/engine/gnn/graph_loader.h"
#include <future> #include <future>
#include <tuple> #include <tuple>
#include <utility> #include <utility>
#include "minddata/dataset/engine/gnn/graph_loader.h" #include "minddata/dataset/engine/gnn/graph_data_impl.h"
#include "mindspore/ccsrc/minddata/mindrecord/include/shard_error.h"
#include "minddata/dataset/engine/gnn/local_edge.h" #include "minddata/dataset/engine/gnn/local_edge.h"
#include "minddata/dataset/engine/gnn/local_node.h" #include "minddata/dataset/engine/gnn/local_node.h"
#include "minddata/dataset/util/task_manager.h" #include "minddata/dataset/util/task_manager.h"
#include "minddata/mindrecord/include/shard_error.h"
using ShardTuple = std::vector<std::tuple<std::vector<uint8_t>, mindspore::mindrecord::json>>; using ShardTuple = std::vector<std::tuple<std::vector<uint8_t>, mindspore::mindrecord::json>>;
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
namespace gnn { namespace gnn {
using mindrecord::MSRStatus; using mindrecord::MSRStatus;
GraphLoader::GraphLoader(std::string mr_filepath, int32_t num_workers) GraphLoader::GraphLoader(GraphDataImpl *graph_impl, std::string mr_filepath, int32_t num_workers, bool server_mode)
: mr_path_(mr_filepath), : graph_impl_(graph_impl),
mr_path_(mr_filepath),
num_workers_(num_workers), num_workers_(num_workers),
row_id_(0), row_id_(0),
shard_reader_(nullptr), shard_reader_(nullptr),
graph_feature_parser_(nullptr),
keys_({"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}) {} keys_({"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}) {}
Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, NodeTypeMap *n_type_map, Status GraphLoader::GetNodesAndEdges() {
EdgeTypeMap *e_type_map, NodeFeatureMap *n_feature_map, NodeIdMap *n_id_map = &graph_impl_->node_id_map_;
EdgeFeatureMap *e_feature_map, DefaultNodeFeatureMap *default_node_feature_map, EdgeIdMap *e_id_map = &graph_impl_->edge_id_map_;
DefaultEdgeFeatureMap *default_edge_feature_map) {
for (std::deque<std::shared_ptr<Node>> &dq : n_deques_) { for (std::deque<std::shared_ptr<Node>> &dq : n_deques_) {
while (dq.empty() == false) { while (dq.empty() == false) {
std::shared_ptr<Node> node_ptr = dq.front(); std::shared_ptr<Node> node_ptr = dq.front();
n_id_map->insert({node_ptr->id(), node_ptr}); n_id_map->insert({node_ptr->id(), node_ptr});
(*n_type_map)[node_ptr->type()].push_back(node_ptr->id()); graph_impl_->node_type_map_[node_ptr->type()].push_back(node_ptr->id());
dq.pop_front(); dq.pop_front();
} }
} }
@ -63,15 +64,15 @@ Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, N
RETURN_IF_NOT_OK(edge_ptr->SetNode({src_itr->second, dst_itr->second})); RETURN_IF_NOT_OK(edge_ptr->SetNode({src_itr->second, dst_itr->second}));
RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second)); RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second));
e_id_map->insert({edge_ptr->id(), edge_ptr}); // add edge to edge_id_map_ e_id_map->insert({edge_ptr->id(), edge_ptr}); // add edge to edge_id_map_
(*e_type_map)[edge_ptr->type()].push_back(edge_ptr->id()); graph_impl_->edge_type_map_[edge_ptr->type()].push_back(edge_ptr->id());
dq.pop_front(); dq.pop_front();
} }
} }
for (auto &itr : *n_type_map) itr.second.shrink_to_fit(); for (auto &itr : graph_impl_->node_type_map_) itr.second.shrink_to_fit();
for (auto &itr : *e_type_map) itr.second.shrink_to_fit(); for (auto &itr : graph_impl_->edge_type_map_) itr.second.shrink_to_fit();
MergeFeatureMaps(n_feature_map, e_feature_map, default_node_feature_map, default_edge_feature_map); MergeFeatureMaps();
return Status::OK(); return Status::OK();
} }
@ -92,13 +93,26 @@ Status GraphLoader::InitAndLoad() {
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetShardHeader()->GetSchemaCount() > 0, "No schema found!"); CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetShardHeader()->GetSchemaCount() > 0, "No schema found!");
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Launch(true) == MSRStatus::SUCCESS, "fail to launch mr"); CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Launch(true) == MSRStatus::SUCCESS, "fail to launch mr");
mindrecord::json schema = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema())["schema"]; graph_impl_->data_schema_ = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema());
mindrecord::json schema = graph_impl_->data_schema_["schema"];
for (const std::string &key : keys_) { for (const std::string &key : keys_) {
if (schema.find(key) == schema.end()) { if (schema.find(key) == schema.end()) {
RETURN_STATUS_UNEXPECTED(key + ":doesn't exist in schema:" + schema.dump()); RETURN_STATUS_UNEXPECTED(key + ":doesn't exist in schema:" + schema.dump());
} }
} }
if (graph_impl_->server_mode_) {
#if !defined(_WIN32) && !defined(_WIN64)
int64_t total_blob_size = 0;
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetTotalBlobSize(&total_blob_size) == MSRStatus::SUCCESS,
"failed to get total blob size");
graph_impl_->graph_shared_memory_ = std::make_unique<GraphSharedMemory>(total_blob_size, mr_path_);
RETURN_IF_NOT_OK(graph_impl_->graph_shared_memory_->CreateSharedMemory());
#endif
}
graph_feature_parser_ = std::make_unique<GraphFeatureParser>(*shard_reader_->GetShardColumn());
// launching worker threads // launching worker threads
for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) { for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) {
RETURN_IF_NOT_OK(vg.CreateAsyncTask("GraphLoader", std::bind(&GraphLoader::WorkerEntry, this, wkr_id))); RETURN_IF_NOT_OK(vg.CreateAsyncTask("GraphLoader", std::bind(&GraphLoader::WorkerEntry, this, wkr_id)));
@ -116,18 +130,39 @@ Status GraphLoader::LoadNode(const std::vector<uint8_t> &col_blob, const mindrec
NodeType node_type = static_cast<NodeType>(col_jsn["type"]); NodeType node_type = static_cast<NodeType>(col_jsn["type"]);
(*node) = std::make_shared<LocalNode>(node_id, node_type); (*node) = std::make_shared<LocalNode>(node_id, node_type);
std::vector<int32_t> indices; std::vector<int32_t> indices;
RETURN_IF_NOT_OK(LoadFeatureIndex("node_feature_index", col_blob, col_jsn, &indices)); RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureIndex("node_feature_index", col_blob, &indices));
if (graph_impl_->server_mode_) {
for (int32_t ind : indices) { #if !defined(_WIN32) && !defined(_WIN64)
std::shared_ptr<Tensor> tensor; for (int32_t ind : indices) {
RETURN_IF_NOT_OK(LoadFeatureTensor("node_feature_" + std::to_string(ind), col_blob, col_jsn, &tensor)); std::shared_ptr<Tensor> tensor_sm;
RETURN_IF_NOT_OK((*node)->UpdateFeature(std::make_shared<Feature>(ind, tensor))); RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureToSharedMemory(
(*feature_map)[node_type].insert(ind); "node_feature_" + std::to_string(ind), col_blob, graph_impl_->graph_shared_memory_.get(), &tensor_sm));
if ((*default_feature)[ind] == nullptr) { RETURN_IF_NOT_OK((*node)->UpdateFeature(std::make_shared<Feature>(ind, tensor_sm, true)));
std::shared_ptr<Tensor> zero_tensor; (*feature_map)[node_type].insert(ind);
RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor)); if ((*default_feature)[ind] == nullptr) {
RETURN_IF_NOT_OK(zero_tensor->Zero()); std::shared_ptr<Tensor> tensor;
(*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor); RETURN_IF_NOT_OK(
graph_feature_parser_->LoadFeatureTensor("node_feature_" + std::to_string(ind), col_blob, &tensor));
std::shared_ptr<Tensor> zero_tensor;
RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor));
RETURN_IF_NOT_OK(zero_tensor->Zero());
(*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor);
}
}
#endif
} else {
for (int32_t ind : indices) {
std::shared_ptr<Tensor> tensor;
RETURN_IF_NOT_OK(
graph_feature_parser_->LoadFeatureTensor("node_feature_" + std::to_string(ind), col_blob, &tensor));
RETURN_IF_NOT_OK((*node)->UpdateFeature(std::make_shared<Feature>(ind, tensor)));
(*feature_map)[node_type].insert(ind);
if ((*default_feature)[ind] == nullptr) {
std::shared_ptr<Tensor> zero_tensor;
RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor));
RETURN_IF_NOT_OK(zero_tensor->Zero());
(*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor);
}
} }
} }
return Status::OK(); return Status::OK();
@ -143,63 +178,42 @@ Status GraphLoader::LoadEdge(const std::vector<uint8_t> &col_blob, const mindrec
std::shared_ptr<Node> dst = std::make_shared<LocalNode>(dst_id, -1); std::shared_ptr<Node> dst = std::make_shared<LocalNode>(dst_id, -1);
(*edge) = std::make_shared<LocalEdge>(edge_id, edge_type, src, dst); (*edge) = std::make_shared<LocalEdge>(edge_id, edge_type, src, dst);
std::vector<int32_t> indices; std::vector<int32_t> indices;
RETURN_IF_NOT_OK(LoadFeatureIndex("edge_feature_index", col_blob, col_jsn, &indices)); RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureIndex("edge_feature_index", col_blob, &indices));
for (int32_t ind : indices) { if (graph_impl_->server_mode_) {
std::shared_ptr<Tensor> tensor; #if !defined(_WIN32) && !defined(_WIN64)
RETURN_IF_NOT_OK(LoadFeatureTensor("edge_feature_" + std::to_string(ind), col_blob, col_jsn, &tensor)); for (int32_t ind : indices) {
RETURN_IF_NOT_OK((*edge)->UpdateFeature(std::make_shared<Feature>(ind, tensor))); std::shared_ptr<Tensor> tensor_sm;
(*feature_map)[edge_type].insert(ind); RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureToSharedMemory(
if ((*default_feature)[ind] == nullptr) { "edge_feature_" + std::to_string(ind), col_blob, graph_impl_->graph_shared_memory_.get(), &tensor_sm));
std::shared_ptr<Tensor> zero_tensor; RETURN_IF_NOT_OK((*edge)->UpdateFeature(std::make_shared<Feature>(ind, tensor_sm, true)));
RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor)); (*feature_map)[edge_type].insert(ind);
RETURN_IF_NOT_OK(zero_tensor->Zero()); if ((*default_feature)[ind] == nullptr) {
(*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor); std::shared_ptr<Tensor> tensor;
RETURN_IF_NOT_OK(
graph_feature_parser_->LoadFeatureTensor("edge_feature_" + std::to_string(ind), col_blob, &tensor));
std::shared_ptr<Tensor> zero_tensor;
RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor));
RETURN_IF_NOT_OK(zero_tensor->Zero());
(*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor);
}
}
#endif
} else {
for (int32_t ind : indices) {
std::shared_ptr<Tensor> tensor;
RETURN_IF_NOT_OK(
graph_feature_parser_->LoadFeatureTensor("edge_feature_" + std::to_string(ind), col_blob, &tensor));
RETURN_IF_NOT_OK((*edge)->UpdateFeature(std::make_shared<Feature>(ind, tensor)));
(*feature_map)[edge_type].insert(ind);
if ((*default_feature)[ind] == nullptr) {
std::shared_ptr<Tensor> zero_tensor;
RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor));
RETURN_IF_NOT_OK(zero_tensor->Zero());
(*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor);
}
} }
} }
return Status::OK();
}
// Reads the column named `key` from one mindrecord row and wraps its bytes in a flat (1-D) tensor.
// @param const std::string &key - column name
// @param const std::vector<uint8_t> &col_blob - contains data in blob field in mindrecord
// @param const mindrecord::json &col_jsn - contains raw data
// @param std::shared_ptr<Tensor> *tensor - return value, feature tensor
// @return Status - the status code
Status GraphLoader::LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &col_blob,
                                      const mindrecord::json &col_jsn, std::shared_ptr<Tensor> *tensor) {
  const unsigned char *data = nullptr;
  std::unique_ptr<unsigned char[]> data_ptr;
  uint64_t n_bytes = 0, col_type_size = 1;
  mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
  std::vector<int64_t> column_shape;
  MSRStatus rs = shard_reader_->GetShardColumn()->GetColumnValueByName(
    key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, &col_type, &col_type_size, &column_shape);
  // Same "column:<name>" wording as LoadFeatureIndex so the column name is readable in the message.
  CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key);
  // When `data` was not set, fall back to the buffer held by `data_ptr`.
  if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]);
  // Element count = byte count / element size; the shape and type arguments are already temporaries,
  // so no std::move is needed.
  RETURN_IF_NOT_OK(Tensor::CreateFromMemory(TensorShape({static_cast<dsize_t>(n_bytes / col_type_size)}),
                                            DataType(mindrecord::ColumnDataTypeNameNormalized[col_type]), data,
                                            tensor));
  return Status::OK();
}
Status GraphLoader::LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &col_blob,
const mindrecord::json &col_jsn, std::vector<int32_t> *indices) {
const unsigned char *data = nullptr;
std::unique_ptr<unsigned char[]> data_ptr;
uint64_t n_bytes = 0, col_type_size = 1;
mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
std::vector<int64_t> column_shape;
MSRStatus rs = shard_reader_->GetShardColumn()->GetColumnValueByName(
key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, &col_type, &col_type_size, &column_shape);
CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key);
if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]);
for (int i = 0; i < n_bytes; i += col_type_size) {
int32_t feature_ind = -1;
if (col_type == mindrecord::ColumnInt32) {
feature_ind = *(reinterpret_cast<const int32_t *>(data + i));
} else if (col_type == mindrecord::ColumnInt64) {
feature_ind = *(reinterpret_cast<const int64_t *>(data + i));
} else {
RETURN_STATUS_UNEXPECTED("Feature Index needs to be int32/int64 type!");
}
if (feature_ind >= 0) indices->push_back(feature_ind);
}
return Status::OK(); return Status::OK();
} }
@ -234,21 +248,19 @@ Status GraphLoader::WorkerEntry(int32_t worker_id) {
return Status::OK(); return Status::OK();
} }
void GraphLoader::MergeFeatureMaps(NodeFeatureMap *n_feature_map, EdgeFeatureMap *e_feature_map, void GraphLoader::MergeFeatureMaps() {
DefaultNodeFeatureMap *default_node_feature_map,
DefaultEdgeFeatureMap *default_edge_feature_map) {
for (int wkr_id = 0; wkr_id < num_workers_; wkr_id++) { for (int wkr_id = 0; wkr_id < num_workers_; wkr_id++) {
for (auto &m : n_feature_maps_[wkr_id]) { for (auto &m : n_feature_maps_[wkr_id]) {
for (auto &n : m.second) (*n_feature_map)[m.first].insert(n); for (auto &n : m.second) graph_impl_->node_feature_map_[m.first].insert(n);
} }
for (auto &m : e_feature_maps_[wkr_id]) { for (auto &m : e_feature_maps_[wkr_id]) {
for (auto &n : m.second) (*e_feature_map)[m.first].insert(n); for (auto &n : m.second) graph_impl_->edge_feature_map_[m.first].insert(n);
} }
for (auto &m : default_node_feature_maps_[wkr_id]) { for (auto &m : default_node_feature_maps_[wkr_id]) {
(*default_node_feature_map)[m.first] = m.second; graph_impl_->default_node_feature_map_[m.first] = m.second;
} }
for (auto &m : default_edge_feature_maps_[wkr_id]) { for (auto &m : default_edge_feature_maps_[wkr_id]) {
(*default_edge_feature_map)[m.first] = m.second; graph_impl_->default_edge_feature_map_[m.first] = m.second;
} }
} }
n_feature_maps_.clear(); n_feature_maps_.clear();

View File

@ -26,10 +26,13 @@
#include "minddata/dataset/core/data_type.h" #include "minddata/dataset/core/data_type.h"
#include "minddata/dataset/core/tensor.h" #include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/engine/gnn/feature.h"
#include "minddata/dataset/engine/gnn/graph.h"
#include "minddata/dataset/engine/gnn/node.h"
#include "minddata/dataset/engine/gnn/edge.h" #include "minddata/dataset/engine/gnn/edge.h"
#include "minddata/dataset/engine/gnn/feature.h"
#include "minddata/dataset/engine/gnn/graph_feature_parser.h"
#if !defined(_WIN32) && !defined(_WIN64)
#include "minddata/dataset/engine/gnn/graph_shared_memory.h"
#endif
#include "minddata/dataset/engine/gnn/node.h"
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"
#include "minddata/mindrecord/include/shard_reader.h" #include "minddata/mindrecord/include/shard_reader.h"
namespace mindspore { namespace mindspore {
@ -46,13 +49,15 @@ using EdgeFeatureMap = std::unordered_map<EdgeType, std::unordered_set<FeatureTy
using DefaultNodeFeatureMap = std::unordered_map<FeatureType, std::shared_ptr<Feature>>; using DefaultNodeFeatureMap = std::unordered_map<FeatureType, std::shared_ptr<Feature>>;
using DefaultEdgeFeatureMap = std::unordered_map<FeatureType, std::shared_ptr<Feature>>; using DefaultEdgeFeatureMap = std::unordered_map<FeatureType, std::shared_ptr<Feature>>;
class GraphDataImpl;
// this class interfaces with the underlying storage format (mindrecord) // this class interfaces with the underlying storage format (mindrecord)
// it returns raw nodes and edges via GetNodesAndEdges // it returns raw nodes and edges via GetNodesAndEdges
// it is then the responsibility of graph to construct itself based on the nodes and edges // it is then the responsibility of graph to construct itself based on the nodes and edges
// if needed, this class could become a base where each derived class handles a specific storage format // if needed, this class could become a base where each derived class handles a specific storage format
class GraphLoader { class GraphLoader {
public: public:
explicit GraphLoader(std::string mr_filepath, int32_t num_workers = 4); GraphLoader(GraphDataImpl *graph_impl, std::string mr_filepath, int32_t num_workers = 4, bool server_mode = false);
~GraphLoader() = default; ~GraphLoader() = default;
// Init mindrecord and load everything into memory multi-threaded // Init mindrecord and load everything into memory multi-threaded
@ -63,8 +68,7 @@ class GraphLoader {
// nodes and edges are added to map without any connection. That's because there nodes and edges are read in // nodes and edges are added to map without any connection. That's because there nodes and edges are read in
// random order. src_node and dst_node in Edge are node_id only with -1 as type. // random order. src_node and dst_node in Edge are node_id only with -1 as type.
// features attached to each node and edge are expected to be filled correctly // features attached to each node and edge are expected to be filled correctly
Status GetNodesAndEdges(NodeIdMap *, EdgeIdMap *, NodeTypeMap *, EdgeTypeMap *, NodeFeatureMap *, EdgeFeatureMap *, Status GetNodesAndEdges();
DefaultNodeFeatureMap *, DefaultEdgeFeatureMap *);
private: private:
// //
@ -92,29 +96,15 @@ class GraphLoader {
Status LoadEdge(const std::vector<uint8_t> &blob, const mindrecord::json &jsn, std::shared_ptr<Edge> *edge, Status LoadEdge(const std::vector<uint8_t> &blob, const mindrecord::json &jsn, std::shared_ptr<Edge> *edge,
EdgeFeatureMap *feature_map, DefaultEdgeFeatureMap *default_feature); EdgeFeatureMap *feature_map, DefaultEdgeFeatureMap *default_feature);
// @param std::string key - column name
// @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
// @param mindrecord::json &jsn - contains raw data
// @param std::vector<int32_t> *ind - return value, list of feature index in int32_t
// @return Status - the status code
Status LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &blob, const mindrecord::json &jsn,
std::vector<int32_t> *ind);
// @param std::string &key - column name
// @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
// @param mindrecord::json &jsn - contains raw data
// @param std::shared_ptr<Tensor> *tensor - return value feature tensor
// @return Status - the status code
Status LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &blob, const mindrecord::json &jsn,
std::shared_ptr<Tensor> *tensor);
// merge NodeFeatureMap and EdgeFeatureMap of each worker into 1 // merge NodeFeatureMap and EdgeFeatureMap of each worker into 1
void MergeFeatureMaps(NodeFeatureMap *, EdgeFeatureMap *, DefaultNodeFeatureMap *, DefaultEdgeFeatureMap *); void MergeFeatureMaps();
GraphDataImpl *graph_impl_;
std::string mr_path_;
const int32_t num_workers_; const int32_t num_workers_;
std::atomic_int row_id_; std::atomic_int row_id_;
std::string mr_path_;
std::unique_ptr<ShardReader> shard_reader_; std::unique_ptr<ShardReader> shard_reader_;
std::unique_ptr<GraphFeatureParser> graph_feature_parser_;
std::vector<std::deque<std::shared_ptr<Node>>> n_deques_; std::vector<std::deque<std::shared_ptr<Node>>> n_deques_;
std::vector<std::deque<std::shared_ptr<Edge>>> e_deques_; std::vector<std::deque<std::shared_ptr<Edge>>> e_deques_;
std::vector<NodeFeatureMap> n_feature_maps_; std::vector<NodeFeatureMap> n_feature_maps_;

View File

@ -0,0 +1,134 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/gnn/graph_shared_memory.h"
#include <string>
#include "utils/log_adapter.h"
namespace mindspore {
namespace dataset {
namespace gnn {
// Builds a handle for the segment identified by an explicit key.
// Nothing is created or attached here; that happens in CreateSharedMemory / GetSharedMemory.
GraphSharedMemory::GraphSharedMemory(int64_t memory_size, key_t memory_key)
    : memory_size_(memory_size),
      memory_key_(memory_key),
      memory_ptr_(nullptr),
      memory_offset_(0),
      is_new_create_(false) {
  // Keep a hexadecimal rendering of the key for log and error messages.
  std::ostringstream hex_key;
  hex_key << std::hex << memory_key_;
  memory_key_str_ = hex_key.str();
}
GraphSharedMemory::GraphSharedMemory(int64_t memory_size, const std::string &mr_file)
: mr_file_(mr_file),
memory_size_(memory_size),
memory_key_(-1),
memory_ptr_(nullptr),
memory_offset_(0),
is_new_create_(false) {}
// Removes the segment only if this object created it; the result of the removal is ignored.
GraphSharedMemory::~GraphSharedMemory() {
  if (!is_new_create_) {
    return;
  }
  (void)DeleteSharedMemory();
}
// Creates the shared memory segment and attaches it to this process.
// If no key was given at construction, one is derived from the mindrecord file path with ftok.
// An exclusive create is tried first; if that fails (for example because a segment with the same key
// already exists), the existing segment is opened instead and is_new_create_ stays false.
// @return Status - the status code
Status GraphSharedMemory::CreateSharedMemory() {
  if (memory_key_ == -1) {
    // ftok to generate unique key
    memory_key_ = ftok(mr_file_.c_str(), kGnnSharedMemoryId);
    CHECK_FAIL_RETURN_UNEXPECTED(memory_key_ != -1, "Failed to get key of shared memory. file_name:" + mr_file_);
    std::stringstream stream;
    stream << std::hex << memory_key_;
    memory_key_str_ = stream.str();
  }
  int shmflg = (0666 | IPC_CREAT | IPC_EXCL);
  Status s = SharedMemoryImpl(shmflg);
  if (s.IsOk()) {
    // Only a segment created here is removed again by the destructor.
    is_new_create_ = true;
    MS_LOG(INFO) << "Create shared memory success, key=0x" << memory_key_str_;
  } else {
    MS_LOG(WARNING) << "Shared memory with the same key may already exist, key=0x" << memory_key_str_;
    // Retry without IPC_EXCL so that an already existing segment is accepted.
    shmflg = (0666 | IPC_CREAT);
    s = SharedMemoryImpl(shmflg);
    if (!s.IsOk()) {
      RETURN_STATUS_UNEXPECTED("Create shared memory failed, key=0x" + memory_key_str_);
    }
  }
  return Status::OK();
}
// Attaches to the segment for memory_key_ using flags of 0 (no IPC_CREAT, so nothing is created).
// @return Status - the status code
Status GraphSharedMemory::GetSharedMemory() {
  const int attach_only_flags = 0;
  RETURN_IF_NOT_OK(SharedMemoryImpl(attach_only_flags));
  return Status::OK();
}
// Looks the segment up by key and asks the system to remove it (IPC_RMID).
// @return Status - the status code
Status GraphSharedMemory::DeleteSharedMemory() {
  const int segment_id = shmget(memory_key_, 0, 0);
  CHECK_FAIL_RETURN_UNEXPECTED(segment_id != -1, "Failed to get shared memory. key=0x" + memory_key_str_);
  const int rc = shmctl(segment_id, IPC_RMID, nullptr);
  CHECK_FAIL_RETURN_UNEXPECTED(rc != -1, "Failed to delete shared memory. key=0x" + memory_key_str_);
  return Status::OK();
}
// Obtains the segment for memory_key_ with the given flags and attaches it, storing the address in memory_ptr_.
// @param const int &shmflg - flags forwarded to shmget
// @return Status - the status code
Status GraphSharedMemory::SharedMemoryImpl(const int &shmflg) {
  // shmget returns an identifier for the segment
  const int segment_id = shmget(memory_key_, memory_size_, shmflg);
  CHECK_FAIL_RETURN_UNEXPECTED(segment_id != -1, "Failed to get shared memory. key=0x" + memory_key_str_);
  // shmat attaches the segment; it reports failure with the address value -1
  void *const attached = shmat(segment_id, nullptr, 0);
  const bool attach_ok = (attached != reinterpret_cast<void *>(-1));
  CHECK_FAIL_RETURN_UNEXPECTED(attach_ok, "Failed to address shared memory. key=0x" + memory_key_str_);
  memory_ptr_ = reinterpret_cast<uint8_t *>(attached);
  return Status::OK();
}
// Appends `len` bytes to the shared memory and reports where they were written.
// Writers are serialized with mutex_, so each call gets its own non-overlapping region.
// @param const uint8_t *data - bytes to copy
// @param int64_t len - number of bytes to copy, must be positive
// @param int64_t *offset - return value, start position of the copied bytes inside the shared memory
// @return Status - the status code
Status GraphSharedMemory::InsertData(const uint8_t *data, int64_t len, int64_t *offset) {
  CHECK_FAIL_RETURN_UNEXPECTED(data, "Input data is nullptr.");
  CHECK_FAIL_RETURN_UNEXPECTED(len > 0, "Input len is invalid.");
  // `offset` is written unconditionally below, so reject a null output pointer up front.
  CHECK_FAIL_RETURN_UNEXPECTED(offset, "Output offset is nullptr.");
  std::lock_guard<std::mutex> lck(mutex_);
  // memory_ptr_ stays nullptr until CreateSharedMemory / GetSharedMemory attached the segment.
  CHECK_FAIL_RETURN_UNEXPECTED(memory_ptr_ != nullptr, "Shared memory is not attached.");
  CHECK_FAIL_RETURN_UNEXPECTED((memory_size_ - memory_offset_ >= len),
                               "Insufficient shared memory space to insert data.");
  if (EOK != memcpy_s(memory_ptr_ + memory_offset_, memory_size_ - memory_offset_, data, len)) {
    RETURN_STATUS_UNEXPECTED("Failed to insert data into shared memory.");
  }
  *offset = memory_offset_;
  memory_offset_ += len;
  return Status::OK();
}
// Copies `get_data_len` bytes, starting at `offset` inside the shared memory, into `data`.
// @param uint8_t *data - destination buffer
// @param int64_t data_len - capacity of the destination buffer in bytes
// @param int64_t offset - start position inside the shared memory, must lie within [0, memory_size_]
// @param int64_t get_data_len - number of bytes to copy, must be positive
// @return Status - the status code
Status GraphSharedMemory::GetData(uint8_t *data, int64_t data_len, int64_t offset, int64_t get_data_len) {
  CHECK_FAIL_RETURN_UNEXPECTED(data, "Input data is nullptr.");
  CHECK_FAIL_RETURN_UNEXPECTED(get_data_len > 0, "Input get_data_len is invalid.");
  CHECK_FAIL_RETURN_UNEXPECTED(data_len >= get_data_len, "Insufficient target address space.");
  // memory_ptr_ stays nullptr until CreateSharedMemory / GetSharedMemory attached the segment.
  CHECK_FAIL_RETURN_UNEXPECTED(memory_ptr_ != nullptr, "Shared memory is not attached.");
  // A negative offset would make the copy start before the beginning of the segment.
  CHECK_FAIL_RETURN_UNEXPECTED(offset >= 0 && offset <= memory_size_, "Input offset is invalid.");
  // Expressed as a subtraction so that `get_data_len + offset` cannot overflow.
  CHECK_FAIL_RETURN_UNEXPECTED(memory_size_ - offset >= get_data_len,
                               "get_data_len is too large, beyond the space of shared memory.");
  if (EOK != memcpy_s(data, data_len, memory_ptr_ + offset, get_data_len)) {
    RETURN_STATUS_UNEXPECTED("Failed to get data from shared memory.");
  }
  return Status::OK();
}
} // namespace gnn
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,72 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_SHARED_MEMORY_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_SHARED_MEMORY_H_
#include <sys/ipc.h>
#include <sys/shm.h>
#include <mutex>
#include <string>
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
namespace gnn {
const int kGnnSharedMemoryId = 65;
// Handle for one System V shared memory segment (shmget / shmat / shmctl).
// Data is appended with InsertData and read back by offset with GetData.
class GraphSharedMemory {
 public:
  // @param int64_t memory_size - size of the segment in bytes
  // @param key_t memory_key - key that identifies the segment
  explicit GraphSharedMemory(int64_t memory_size, key_t memory_key);

  // @param int64_t memory_size - size of the segment in bytes
  // @param const std::string &mr_file - file path from which CreateSharedMemory derives the key (ftok)
  explicit GraphSharedMemory(int64_t memory_size, const std::string &mr_file);

  // Removes the segment if it was newly created by this object.
  ~GraphSharedMemory();

  // Create the segment (or open it if one with the same key already exists) and attach it.
  // @return Status - the status code
  Status CreateSharedMemory();

  // Attach to the segment identified by the key without creating it.
  // @return Status - the status code
  Status GetSharedMemory();

  // Remove the segment identified by the key.
  // @return Status - the status code
  Status DeleteSharedMemory();

  // Append data to the shared memory.
  // @param const uint8_t *data - bytes to copy
  // @param int64_t len - number of bytes to copy
  // @param int64_t *offset - return value, position of the copied bytes inside the shared memory
  // @return Status - the status code
  Status InsertData(const uint8_t *data, int64_t len, int64_t *offset);

  // Copy data out of the shared memory.
  // @param uint8_t *data - destination buffer
  // @param int64_t data_len - capacity of the destination buffer in bytes
  // @param int64_t offset - start position inside the shared memory
  // @param int64_t get_data_len - number of bytes to copy
  // @return Status - the status code
  Status GetData(uint8_t *data, int64_t data_len, int64_t offset, int64_t get_data_len);

  key_t memory_key() const { return memory_key_; }

  int64_t memory_size() const { return memory_size_; }

 private:
  // Shared implementation of create/get: shmget with the given flags followed by shmat.
  Status SharedMemoryImpl(const int &shmflg);

  std::string mr_file_;         // file path used to derive the key when none is given
  int64_t memory_size_;         // size of the segment in bytes
  key_t memory_key_;            // key of the segment, -1 until it is known
  std::string memory_key_str_;  // hexadecimal form of the key, used in messages
  uint8_t *memory_ptr_;         // address of the attached segment, nullptr until attached
  int64_t memory_offset_;       // position at which the next InsertData writes
  std::mutex mutex_;            // serializes InsertData
  bool is_new_create_;          // true if CreateSharedMemory created the segment
};
} // namespace gnn
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_SHARED_MEMORY_H_

View File

@ -0,0 +1,82 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/gnn/grpc_async_server.h"
#include <limits>
#include "minddata/dataset/util/task_manager.h"
#include "utils/log_adapter.h"
namespace mindspore {
namespace dataset {
// Stores the address to serve on; the server itself is started by Run().
// Initializers are listed in member declaration order (port_, then host_), which is the order
// in which they are actually executed.
GrpcAsyncServer::GrpcAsyncServer(const std::string &host, int32_t port) : port_(port), host_(host) {}
// Shuts the server and its completion queue down on destruction.
GrpcAsyncServer::~GrpcAsyncServer() {
  Stop();
}
// Configures a gRPC server for host_:port_, registers the derived class's service,
// creates the completion queue and starts the server.
Status GrpcAsyncServer::Run() {
  const std::string listen_addr = host_ + ":" + std::to_string(port_);
  grpc::ServerBuilder server_builder;
  // Default message size for gRPC is 4MB. Increase it to 2g-1
  server_builder.SetMaxReceiveMessageSize(std::numeric_limits<int32_t>::max());
  server_builder.AddChannelArgument(GRPC_ARG_ALLOW_REUSEPORT, 0);
  // Filled in by gRPC with the port that was bound; compared with port_ if startup fails.
  int bound_port = 0;
  server_builder.AddListeningPort(listen_addr, grpc::InsecureServerCredentials(), &bound_port);
  RETURN_IF_NOT_OK(RegisterService(&server_builder));
  cq_ = server_builder.AddCompletionQueue();
  server_ = server_builder.BuildAndStart();
  if (!server_) {
    std::string err_msg = "Fail to start server. ";
    if (bound_port != port_) {
      err_msg += "Unable to bind to address " + listen_addr + ".";
    }
    RETURN_STATUS_UNEXPECTED(err_msg);
  }
  MS_LOG(INFO) << "Server listening on " << listen_addr;
  return Status::OK();
}
// Event loop of the async server: seeds the completion queue through EnqueueRequest and then
// dispatches every successfully completed tag to ProcessRequest until the queue is shut down.
// @return Status - the status code
Status GrpcAsyncServer::HandleRequest() {
  // The completion queue only exists after a successful Run(); fail cleanly instead of dereferencing null.
  if (!cq_) {
    RETURN_STATUS_UNEXPECTED("Completion queue is not initialized, Run() must succeed first.");
  }
  bool success = false;
  void *tag = nullptr;
  // We loop through the grpc queue. Each connection if successful
  // will come back with our own tag which is an instance of CallData
  // and we simply call its functor. But first we need to create these instances
  // and inject them into the grpc queue.
  RETURN_IF_NOT_OK(EnqueueRequest());
  while (cq_->Next(&tag, &success)) {
    RETURN_IF_INTERRUPTED();
    if (success) {
      RETURN_IF_NOT_OK(ProcessRequest(tag));
    } else {
      MS_LOG(DEBUG) << "cq_->Next failed.";
    }
  }
  return Status::OK();
}
// Shuts down whatever has been started so far; does nothing for parts that were never created.
void GrpcAsyncServer::Stop() {
  if (server_ != nullptr) server_->Shutdown();
  // Always shutdown the completion queue after the server.
  if (cq_ != nullptr) cq_->Shutdown();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,59 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRPC_ASYNC_SERVER_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRPC_ASYNC_SERVER_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "grpcpp/grpcpp.h"
#include "grpcpp/impl/codegen/async_unary_call.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
/// \brief Async server base class
/// \brief Async server base class. Derived classes supply the service registration,
///     the initial requests and the per-tag processing.
class GrpcAsyncServer {
public:
/// \brief Stores the address to serve on; the server is not started here
/// \param[in] host Host part of the listening address
/// \param[in] port Port part of the listening address
explicit GrpcAsyncServer(const std::string &host, int32_t port);
/// \brief Calls Stop()
virtual ~GrpcAsyncServer();
/// \brief Brings up gRPC server
/// \return Status - error if the server could not be started
Status Run();
/// \brief Entry function to handle async server request
/// \return Status - the status code
Status HandleRequest();
/// \brief Shuts down the server and then the completion queue, if they exist
void Stop();
/// \brief Register the service(s) of the derived class on the builder; called by Run()
virtual Status RegisterService(grpc::ServerBuilder *builder) = 0;
/// \brief Inject the initial request instances into the completion queue; called by HandleRequest()
virtual Status EnqueueRequest() = 0;
/// \brief Handle one tag returned by the completion queue; called by HandleRequest()
virtual Status ProcessRequest(void *tag) = 0;
protected:
int32_t port_;
std::string host_;
std::unique_ptr<grpc::ServerCompletionQueue> cq_;
std::unique_ptr<grpc::Server> server_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRPC_ASYNC_SERVER_H_

View File

@ -44,6 +44,7 @@ Status LocalEdge::UpdateFeature(const std::shared_ptr<Feature> &feature) {
return Status::OK(); return Status::OK();
} }
} }
} // namespace gnn } // namespace gnn
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore

View File

@ -20,10 +20,10 @@
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/engine/gnn/edge.h" #include "minddata/dataset/engine/gnn/edge.h"
#include "minddata/dataset/engine/gnn/feature.h" #include "minddata/dataset/engine/gnn/feature.h"
#include "minddata/dataset/engine/gnn/node.h" #include "minddata/dataset/engine/gnn/node.h"
#include "minddata/dataset/util/status.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {

View File

@ -20,9 +20,9 @@
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/engine/gnn/node.h" #include "minddata/dataset/engine/gnn/node.h"
#include "minddata/dataset/engine/gnn/feature.h" #include "minddata/dataset/engine/gnn/feature.h"
#include "minddata/dataset/util/status.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {

View File

@ -20,8 +20,8 @@
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/engine/gnn/feature.h" #include "minddata/dataset/engine/gnn/feature.h"
#include "minddata/dataset/util/status.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {

View File

@ -0,0 +1,84 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/engine/gnn/tensor_proto.h"
#include <algorithm>
#include <utility>
#include <unordered_map>
namespace mindspore {
namespace dataset {
// Lookup table: protobuf tensor type -> dataset DataType.
const std::unordered_map<DataTypePb, DataType::Type> g_pb2datatype_map{
  {DataTypePb::DE_PB_UNKNOWN, DataType::DE_UNKNOWN},
  {DataTypePb::DE_PB_BOOL, DataType::DE_BOOL},
  {DataTypePb::DE_PB_INT8, DataType::DE_INT8},
  {DataTypePb::DE_PB_UINT8, DataType::DE_UINT8},
  {DataTypePb::DE_PB_INT16, DataType::DE_INT16},
  {DataTypePb::DE_PB_UINT16, DataType::DE_UINT16},
  {DataTypePb::DE_PB_INT32, DataType::DE_INT32},
  {DataTypePb::DE_PB_UINT32, DataType::DE_UINT32},
  {DataTypePb::DE_PB_INT64, DataType::DE_INT64},
  {DataTypePb::DE_PB_UINT64, DataType::DE_UINT64},
  {DataTypePb::DE_PB_FLOAT16, DataType::DE_FLOAT16},
  {DataTypePb::DE_PB_FLOAT32, DataType::DE_FLOAT32},
  {DataTypePb::DE_PB_FLOAT64, DataType::DE_FLOAT64},
  {DataTypePb::DE_PB_STRING, DataType::DE_STRING},
};
// Lookup table: dataset DataType -> protobuf tensor type (inverse of g_pb2datatype_map).
const std::unordered_map<DataType::Type, DataTypePb> g_datatype2pb_map{
  {DataType::DE_UNKNOWN, DataTypePb::DE_PB_UNKNOWN},
  {DataType::DE_BOOL, DataTypePb::DE_PB_BOOL},
  {DataType::DE_INT8, DataTypePb::DE_PB_INT8},
  {DataType::DE_UINT8, DataTypePb::DE_PB_UINT8},
  {DataType::DE_INT16, DataTypePb::DE_PB_INT16},
  {DataType::DE_UINT16, DataTypePb::DE_PB_UINT16},
  {DataType::DE_INT32, DataTypePb::DE_PB_INT32},
  {DataType::DE_UINT32, DataTypePb::DE_PB_UINT32},
  {DataType::DE_INT64, DataTypePb::DE_PB_INT64},
  {DataType::DE_UINT64, DataTypePb::DE_PB_UINT64},
  {DataType::DE_FLOAT16, DataTypePb::DE_PB_FLOAT16},
  {DataType::DE_FLOAT32, DataTypePb::DE_PB_FLOAT32},
  {DataType::DE_FLOAT64, DataTypePb::DE_PB_FLOAT64},
  {DataType::DE_STRING, DataTypePb::DE_PB_STRING},
};
// Serializes a tensor into its protobuf message: dims, element type and raw data bytes.
// @param const std::shared_ptr<Tensor> tensor - tensor to serialize
// @param TensorPb *tensor_pb - return value, message that receives the tensor content
// @return Status - the status code
Status TensorToPb(const std::shared_ptr<Tensor> tensor, TensorPb *tensor_pb) {
  CHECK_FAIL_RETURN_UNEXPECTED(tensor, "Parameter tensor is a null pointer");
  CHECK_FAIL_RETURN_UNEXPECTED(tensor_pb, "Parameter tensor_pb is a null pointer");

  // Shape: one protobuf dim per tensor dimension, in order.
  const std::vector<dsize_t> dims = tensor->shape().AsVector();
  for (size_t i = 0; i < dims.size(); ++i) {
    tensor_pb->add_dims(static_cast<google::protobuf::int64>(dims[i]));
  }

  // Element type: translated through the lookup table.
  const auto pb_type = g_datatype2pb_map.find(tensor->type().value());
  if (pb_type == g_datatype2pb_map.end()) {
    RETURN_STATUS_UNEXPECTED("Invalid tensor type: " + tensor->type().ToString());
  }
  tensor_pb->set_tensor_type(pb_type->second);

  // Payload: the tensor buffer copied as raw bytes.
  tensor_pb->set_data(tensor->GetBuffer(), tensor->SizeInBytes());
  return Status::OK();
}
// Rebuilds a tensor from its protobuf message (dims, element type and raw data bytes).
// @param const TensorPb *tensor_pb - message to read
// @param std::shared_ptr<Tensor> *tensor - return value, the reconstructed tensor
// @return Status - the status code
Status PbToTensor(const TensorPb *tensor_pb, std::shared_ptr<Tensor> *tensor) {
  CHECK_FAIL_RETURN_UNEXPECTED(tensor_pb, "Parameter tensor_pb is a null pointer");
  CHECK_FAIL_RETURN_UNEXPECTED(tensor, "Parameter tensor is a null pointer");

  // Shape: copy every protobuf dim, converting to the tensor's dimension type.
  std::vector<dsize_t> dims;
  dims.reserve(tensor_pb->dims().size());
  for (const google::protobuf::int64 dim : tensor_pb->dims()) {
    dims.push_back(static_cast<dsize_t>(dim));
  }

  // Element type: translated through the lookup table.
  const auto de_type = g_pb2datatype_map.find(tensor_pb->tensor_type());
  if (de_type == g_pb2datatype_map.end()) {
    RETURN_STATUS_UNEXPECTED("Invalid Tensor_pb type: " + std::to_string(tensor_pb->tensor_type()));
  }
  DataType::Type type = de_type->second;

  // Build into a local first so that *tensor is only written on success.
  std::shared_ptr<Tensor> result;
  RETURN_IF_NOT_OK(Tensor::CreateFromMemory(TensorShape(dims), DataType(type),
                                            reinterpret_cast<const unsigned char *>(tensor_pb->data().data()),
                                            tensor_pb->data().size(), &result));
  *tensor = std::move(result);
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,36 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_TENSOR_PROTO_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_TENSOR_PROTO_H_
#include <deque>
#include <memory>
#include <vector>
#include "proto/gnn_tensor.pb.h"
#include "minddata/dataset/core/tensor.h"
namespace mindspore {
namespace dataset {
// Serialize a tensor (dims, element type, raw data bytes) into a protobuf message.
// @param const std::shared_ptr<Tensor> tensor - tensor to serialize, must not be null
// @param TensorPb *tensor_pb - return value, message that receives the tensor content, must not be null
// @return Status - the status code
Status TensorToPb(const std::shared_ptr<Tensor> tensor, TensorPb *tensor_pb);
// Rebuild a tensor from a protobuf message.
// @param const TensorPb *tensor_pb - message to read, must not be null
// @param std::shared_ptr<Tensor> *tensor - return value, the reconstructed tensor, must not be null
// @return Status - the status code
Status PbToTensor(const TensorPb *tensor_pb, std::shared_ptr<Tensor> *tensor);
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_TENSOR_PROTO_H_

View File

@ -61,6 +61,7 @@ const std::unordered_map<std::string, ColumnDataType> ColumnDataTypeMap = {
class ShardColumn { class ShardColumn {
public: public:
explicit ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool compress_integer = true); explicit ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool compress_integer = true);
explicit ShardColumn(const json &schema_json, bool compress_integer = true);
~ShardColumn() = default; ~ShardColumn() = default;
@ -72,23 +73,29 @@ class ShardColumn {
std::vector<int64_t> *column_shape); std::vector<int64_t> *column_shape);
/// \brief compress blob /// \brief compress blob
std::vector<uint8_t> CompressBlob(const std::vector<uint8_t> &blob); std::vector<uint8_t> CompressBlob(const std::vector<uint8_t> &blob, int64_t *compression_size);
/// \brief check if blob compressed /// \brief check if blob compressed
bool CheckCompressBlob() const { return has_compress_blob_; } bool CheckCompressBlob() const { return has_compress_blob_; }
/// \brief getter
uint64_t GetNumBlobColumn() const { return num_blob_column_; } uint64_t GetNumBlobColumn() const { return num_blob_column_; }
/// \brief getter
std::vector<std::string> GetColumnName() { return column_name_; } std::vector<std::string> GetColumnName() { return column_name_; }
/// \brief getter
std::vector<ColumnDataType> GeColumnDataType() { return column_data_type_; } std::vector<ColumnDataType> GeColumnDataType() { return column_data_type_; }
/// \brief getter
std::vector<std::vector<int64_t>> GetColumnShape() { return column_shape_; } std::vector<std::vector<int64_t>> GetColumnShape() { return column_shape_; }
/// \brief get column value from blob /// \brief get column value from blob
MSRStatus GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob, MSRStatus GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr, const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr,
uint64_t *const n_bytes); uint64_t *const n_bytes);
/// \brief get column type
std::pair<MSRStatus, ColumnCategory> GetColumnTypeByName(const std::string &column_name, std::pair<MSRStatus, ColumnCategory> GetColumnTypeByName(const std::string &column_name,
ColumnDataType *column_data_type, ColumnDataType *column_data_type,
uint64_t *column_data_type_size, uint64_t *column_data_type_size,
@ -99,6 +106,9 @@ class ShardColumn {
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes); std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes);
private: private:
/// \brief initialization
void Init(const json &schema_json, bool compress_integer = true);
/// \brief get float value from json /// \brief get float value from json
template <typename T> template <typename T>
MSRStatus GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value, bool use_double); MSRStatus GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value, bool use_double);

View File

@ -65,6 +65,11 @@ class ShardHeader {
/// \return the Statistic /// \return the Statistic
std::vector<std::shared_ptr<Statistics>> GetStatistics(); std::vector<std::shared_ptr<Statistics>> GetStatistics();
/// \brief add the statistic and save it
/// \param[in] statistic info of slim size
/// \return null
int64_t GetSlimSizeStatistic(const json &slim_size_json);
/// \brief get the fields of the index /// \brief get the fields of the index
/// \return the fields of the index /// \return the fields of the index
std::vector<std::pair<uint64_t, std::string>> GetFields(); std::vector<std::pair<uint64_t, std::string>> GetFields();
@ -114,10 +119,14 @@ class ShardHeader {
uint64_t GetPageSize() const { return page_size_; } uint64_t GetPageSize() const { return page_size_; }
uint64_t GetCompressionSize() const { return compression_size_; }
void SetHeaderSize(const uint64_t &header_size) { header_size_ = header_size; } void SetHeaderSize(const uint64_t &header_size) { header_size_ = header_size; }
void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; } void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; }
void SetCompressionSize(const uint64_t &compression_size) { compression_size_ = compression_size; }
std::vector<std::string> SerializeHeader(); std::vector<std::string> SerializeHeader();
MSRStatus PagesToFile(const std::string dump_file_name); MSRStatus PagesToFile(const std::string dump_file_name);
@ -177,6 +186,7 @@ class ShardHeader {
uint32_t shard_count_; uint32_t shard_count_;
uint64_t header_size_; uint64_t header_size_;
uint64_t page_size_; uint64_t page_size_;
uint64_t compression_size_;
std::shared_ptr<Index> index_; std::shared_ptr<Index> index_;
std::vector<std::string> shard_addresses_; std::vector<std::string> shard_addresses_;

View File

@ -209,6 +209,9 @@ class ShardReader {
/// \brief get all classes /// \brief get all classes
MSRStatus GetAllClasses(const std::string &category_field, std::set<std::string> &categories); MSRStatus GetAllClasses(const std::string &category_field, std::set<std::string> &categories);
/// \brief get the size of blob data
MSRStatus GetTotalBlobSize(int64_t *total_blob_size);
protected: protected:
/// \brief sqlite call back function /// \brief sqlite call back function
static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names); static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names);
@ -323,6 +326,7 @@ class ShardReader {
const std::string kThreadName = "THRD_ITER_"; // prefix of thread name const std::string kThreadName = "THRD_ITER_"; // prefix of thread name
std::vector<std::thread> thread_set_; // thread list std::vector<std::thread> thread_set_; // thread list
int num_rows_; // number of rows int num_rows_; // number of rows
int64_t total_blob_size_; // total size of blob data
std::mutex mtx_delivery_; // locker for delivery std::mutex mtx_delivery_; // locker for delivery
std::condition_variable cv_delivery_; // conditional variable for delivery std::condition_variable cv_delivery_; // conditional variable for delivery
std::condition_variable cv_iterator_; // conditional variable for iterator std::condition_variable cv_iterator_; // conditional variable for iterator

View File

@ -257,6 +257,7 @@ class ShardWriter {
std::mutex check_mutex_; // mutex for data check std::mutex check_mutex_; // mutex for data check
std::atomic<bool> flag_{false}; std::atomic<bool> flag_{false};
std::atomic<int64_t> compression_size_;
}; };
} // namespace mindrecord } // namespace mindrecord
} // namespace mindspore } // namespace mindspore

View File

@ -43,6 +43,7 @@ ShardReader::ShardReader() {
page_size_ = 0; page_size_ = 0;
header_size_ = 0; header_size_ = 0;
num_rows_ = 0; num_rows_ = 0;
total_blob_size_ = 0;
num_padded_ = 0; num_padded_ = 0;
} }
@ -55,9 +56,11 @@ std::pair<MSRStatus, std::vector<std::string>> ShardReader::GetMeta(const std::s
return {FAILED, {}}; return {FAILED, {}};
} }
auto header = ret.second; auto header = ret.second;
meta_data = {{"header_size", header["header_size"]}, {"page_size", header["page_size"]}, uint64_t compression_size = header.contains("compression_size") ? header["compression_size"].get<uint64_t>() : 0;
{"version", header["version"]}, {"index_fields", header["index_fields"]}, meta_data = {{"header_size", header["header_size"]}, {"page_size", header["page_size"]},
{"schema", header["schema"]}, {"blob_fields", header["blob_fields"]}}; {"compression_size", compression_size}, {"version", header["version"]},
{"index_fields", header["index_fields"]}, {"schema", header["schema"]},
{"blob_fields", header["blob_fields"]}};
return {SUCCESS, header["shard_addresses"]}; return {SUCCESS, header["shard_addresses"]};
} }
@ -145,6 +148,11 @@ MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool loa
for (const auto &rg : row_group_summary) { for (const auto &rg : row_group_summary) {
num_rows_ += std::get<3>(rg); num_rows_ += std::get<3>(rg);
} }
auto disk_size = page_size_ * row_group_summary.size();
auto compression_size = shard_header_->GetCompressionSize();
total_blob_size_ = disk_size + compression_size;
MS_LOG(INFO) << "Blob data size, on disk: " << disk_size << " , addtional uncompression: " << compression_size
<< " , Total: " << total_blob_size_;
MS_LOG(INFO) << "Get meta from mindrecord file & index file successfully."; MS_LOG(INFO) << "Get meta from mindrecord file & index file successfully.";
@ -272,6 +280,11 @@ std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummar
return row_group_summary; return row_group_summary;
} }
/// \brief Report the total blob data size held in total_blob_size_.
/// \param[out] total_blob_size receives the value of total_blob_size_; must not be null
/// \return SUCCESS when the value was written, FAILED when the output pointer is null
MSRStatus ShardReader::GetTotalBlobSize(int64_t *total_blob_size) {
  // Guard the out-parameter: dereferencing a null pointer here would crash the caller.
  if (total_blob_size == nullptr) {
    return FAILED;
  }
  *total_blob_size = total_blob_size_;
  return SUCCESS;
}
MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels, MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels,
std::shared_ptr<std::fstream> fs, std::shared_ptr<std::fstream> fs,
std::vector<std::vector<std::vector<uint64_t>>> &offsets, int shard_id, std::vector<std::vector<std::vector<uint64_t>>> &offsets, int shard_id,

View File

@ -28,11 +28,9 @@ using mindspore::MsLogLevel::INFO;
namespace mindspore { namespace mindspore {
namespace mindrecord { namespace mindrecord {
ShardWriter::ShardWriter() ShardWriter::ShardWriter()
: shard_count_(1), : shard_count_(1), header_size_(kDefaultHeaderSize), page_size_(kDefaultPageSize), row_count_(0), schema_count_(1) {
header_size_(kDefaultHeaderSize), compression_size_ = 0;
page_size_(kDefaultPageSize), }
row_count_(0),
schema_count_(1) {}
ShardWriter::~ShardWriter() { ShardWriter::~ShardWriter() {
for (int i = static_cast<int>(file_streams_.size()) - 1; i >= 0; i--) { for (int i = static_cast<int>(file_streams_.size()) - 1; i >= 0; i--) {
@ -201,6 +199,7 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
if (ret == FAILED) { if (ret == FAILED) {
return FAILED; return FAILED;
} }
compression_size_ = shard_header_->GetCompressionSize();
ret = Open(real_addresses, true); ret = Open(real_addresses, true);
if (ret == FAILED) { if (ret == FAILED) {
MS_LOG(ERROR) << "Open file failed"; MS_LOG(ERROR) << "Open file failed";
@ -614,7 +613,9 @@ MSRStatus ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>
// compress blob // compress blob
if (shard_column_->CheckCompressBlob()) { if (shard_column_->CheckCompressBlob()) {
for (auto &blob : blob_data) { for (auto &blob : blob_data) {
blob = shard_column_->CompressBlob(blob); int64_t compression_bytes = 0;
blob = shard_column_->CompressBlob(blob, &compression_bytes);
compression_size_ += compression_bytes;
} }
} }
@ -1177,6 +1178,11 @@ MSRStatus ShardWriter::WriteShardHeader() {
MS_LOG(ERROR) << "Shard header is null"; MS_LOG(ERROR) << "Shard header is null";
return FAILED; return FAILED;
} }
int64_t compression_temp = compression_size_;
uint64_t compression_size = compression_temp > 0 ? compression_temp : 0;
shard_header_->SetCompressionSize(compression_size);
auto shard_header = shard_header_->SerializeHeader(); auto shard_header = shard_header_->SerializeHeader();
// Write header data to multi files // Write header data to multi files
if (shard_count_ > static_cast<int>(file_streams_.size()) || shard_count_ > static_cast<int>(shard_header.size())) { if (shard_count_ > static_cast<int>(file_streams_.size()) || shard_count_ > static_cast<int>(shard_header.size())) {

View File

@ -24,7 +24,15 @@ namespace mindspore {
namespace mindrecord { namespace mindrecord {
ShardColumn::ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool compress_integer) { ShardColumn::ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool compress_integer) {
auto first_schema = shard_header->GetSchemas()[0]; auto first_schema = shard_header->GetSchemas()[0];
auto schema = first_schema->GetSchema()["schema"]; json schema_json = first_schema->GetSchema();
Init(schema_json, compress_integer);
}
ShardColumn::ShardColumn(const json &schema_json, bool compress_integer) { Init(schema_json, compress_integer); }
void ShardColumn::Init(const json &schema_json, bool compress_integer) {
auto schema = schema_json["schema"];
auto blob_fields = schema_json["blob_fields"];
bool has_integer_array = false; bool has_integer_array = false;
for (json::iterator it = schema.begin(); it != schema.end(); ++it) { for (json::iterator it = schema.begin(); it != schema.end(); ++it) {
@ -52,8 +60,6 @@ ShardColumn::ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool
column_name_id_[column_name_[i]] = i; column_name_id_[column_name_[i]] = i;
} }
auto blob_fields = first_schema->GetBlobFields();
for (const auto &field : blob_fields) { for (const auto &field : blob_fields) {
blob_column_.push_back(field); blob_column_.push_back(field);
} }
@ -282,8 +288,9 @@ ColumnCategory ShardColumn::CheckColumnName(const std::string &column_name) {
return it_blob == blob_column_id_.end() ? ColumnInRaw : ColumnInBlob; return it_blob == blob_column_id_.end() ? ColumnInRaw : ColumnInBlob;
} }
std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob) { std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob, int64_t *compression_size) {
// Skip if no compress columns // Skip if no compress columns
*compression_size = 0;
if (!CheckCompressBlob()) return blob; if (!CheckCompressBlob()) return blob;
std::vector<uint8_t> dst_blob; std::vector<uint8_t> dst_blob;
@ -295,7 +302,9 @@ std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob)
// Compress and return is blob has 1 column only // Compress and return is blob has 1 column only
if (num_blob_column_ == 1) { if (num_blob_column_ == 1) {
return CompressInt(blob, int_type); dst_blob = CompressInt(blob, int_type);
*compression_size = static_cast<int64_t>(blob.size()) - static_cast<int64_t>(dst_blob.size());
return dst_blob;
} }
// Just copy and continue if column dat type is not int32/int64 // Just copy and continue if column dat type is not int32/int64
@ -319,6 +328,7 @@ std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob)
i_src += kInt64Len + num_bytes; i_src += kInt64Len + num_bytes;
} }
MS_LOG(DEBUG) << "Compress all blob from " << blob.size() << " to " << dst_blob.size() << "."; MS_LOG(DEBUG) << "Compress all blob from " << blob.size() << " to " << dst_blob.size() << ".";
*compression_size = static_cast<int64_t>(blob.size()) - static_cast<int64_t>(dst_blob.size());
return dst_blob; return dst_blob;
} }

View File

@ -33,7 +33,9 @@ using mindspore::MsLogLevel::ERROR;
namespace mindspore { namespace mindspore {
namespace mindrecord { namespace mindrecord {
std::atomic<bool> thread_status(false); std::atomic<bool> thread_status(false);
ShardHeader::ShardHeader() : shard_count_(0), header_size_(0), page_size_(0) { index_ = std::make_shared<Index>(); } ShardHeader::ShardHeader() : shard_count_(0), header_size_(0), page_size_(0), compression_size_(0) {
index_ = std::make_shared<Index>();
}
MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers, bool load_dataset) { MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers, bool load_dataset) {
shard_count_ = headers.size(); shard_count_ = headers.size();
@ -54,6 +56,7 @@ MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers, bool l
ParseShardAddress(header["shard_addresses"]); ParseShardAddress(header["shard_addresses"]);
header_size_ = header["header_size"].get<uint64_t>(); header_size_ = header["header_size"].get<uint64_t>();
page_size_ = header["page_size"].get<uint64_t>(); page_size_ = header["page_size"].get<uint64_t>();
compression_size_ = header.contains("compression_size") ? header["compression_size"].get<uint64_t>() : 0;
} }
if (SUCCESS != ParsePage(header["page"], shard_index, load_dataset)) { if (SUCCESS != ParsePage(header["page"], shard_index, load_dataset)) {
return FAILED; return FAILED;
@ -146,9 +149,12 @@ std::pair<MSRStatus, json> ShardHeader::BuildSingleHeader(const std::string &fil
return {FAILED, json()}; return {FAILED, json()};
} }
json raw_header = ret.second; json raw_header = ret.second;
uint64_t compression_size =
raw_header.contains("compression_size") ? raw_header["compression_size"].get<uint64_t>() : 0;
json header = {{"shard_addresses", raw_header["shard_addresses"]}, json header = {{"shard_addresses", raw_header["shard_addresses"]},
{"header_size", raw_header["header_size"]}, {"header_size", raw_header["header_size"]},
{"page_size", raw_header["page_size"]}, {"page_size", raw_header["page_size"]},
{"compression_size", compression_size},
{"index_fields", raw_header["index_fields"]}, {"index_fields", raw_header["index_fields"]},
{"blob_fields", raw_header["schema"][0]["blob_fields"]}, {"blob_fields", raw_header["schema"][0]["blob_fields"]},
{"schema", raw_header["schema"][0]["schema"]}, {"schema", raw_header["schema"][0]["schema"]},
@ -343,6 +349,7 @@ std::vector<std::string> ShardHeader::SerializeHeader() {
s += "\"index_fields\":" + index + ","; s += "\"index_fields\":" + index + ",";
s += "\"page\":" + pages[shardId] + ","; s += "\"page\":" + pages[shardId] + ",";
s += "\"page_size\":" + std::to_string(page_size_) + ","; s += "\"page_size\":" + std::to_string(page_size_) + ",";
s += "\"compression_size\":" + std::to_string(compression_size_) + ",";
s += "\"schema\":" + schema + ","; s += "\"schema\":" + schema + ",";
s += "\"shard_addresses\":" + address + ","; s += "\"shard_addresses\":" + address + ",";
s += "\"shard_id\":" + std::to_string(shardId) + ","; s += "\"shard_id\":" + std::to_string(shardId) + ",";

View File

@ -3085,20 +3085,22 @@ def _cpp_sampler_fn(sampler, dataset):
yield tuple([np.array(x, copy=False) for x in val]) yield tuple([np.array(x, copy=False) for x in val])
def _cpp_sampler_fn_mp(sampler, dataset, num_worker): def _cpp_sampler_fn_mp(sampler, dataset, num_worker, multi_process):
""" """
Multiprocessing generator function wrapper for mappable dataset with cpp sampler. Multiprocessing generator function wrapper for mappable dataset with cpp sampler.
""" """
indices = sampler.get_indices() indices = sampler.get_indices()
return _sampler_fn_mp(indices, dataset, num_worker) sample_fn = SamplerFn(dataset, num_worker, multi_process)
return sample_fn.process(indices)
def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker): def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker, multi_process):
""" """
Multiprocessing generator function wrapper for mappable dataset with python sampler. Multiprocessing generator function wrapper for mappable dataset with python sampler.
""" """
indices = _fetch_py_sampler_indices(sampler, num_samples) indices = _fetch_py_sampler_indices(sampler, num_samples)
return _sampler_fn_mp(indices, dataset, num_worker) sample_fn = SamplerFn(dataset, num_worker, multi_process)
return sample_fn.process(indices)
def _fetch_py_sampler_indices(sampler, num_samples): def _fetch_py_sampler_indices(sampler, num_samples):
@ -3132,63 +3134,92 @@ def _fill_worker_indices(workers, indices, idx):
return idx return idx
def _sampler_fn_mp(indices, dataset, num_worker): class SamplerFn:
""" """
Multiprocessing generator function wrapper master process. Multiprocessing or multithread generator function wrapper master process.
""" """
workers = [] def __init__(self, dataset, num_worker, multi_process):
# Event for end of epoch self.workers = []
eoe = multiprocessing.Event() self.num_worker = num_worker
self.multi_process = multi_process
# Event for end of epoch
if multi_process is True:
self.eoe = multiprocessing.Event()
self.eof = multiprocessing.Event()
else:
self.eoe = threading.Event()
self.eof = threading.Event()
# Create workers
for _ in range(num_worker):
if multi_process is True:
worker = _GeneratorWorkerMp(dataset, self.eoe, self.eof)
else:
worker = _GeneratorWorkerMt(dataset, self.eoe, self.eof)
worker.daemon = True
self.workers.append(worker)
# Create workers def process(self, indices):
for _ in range(num_worker): """
worker = _GeneratorWorker(dataset, eoe) The main process, start the child process or child thread, and fill the index queue,
worker.daemon = True get the result from the result and return.
workers.append(worker) """
# Fill initial index queues
idx_cursor = 0
idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor)
# Fill initial index queues # Start all workers
idx_cursor = 0 for w in self.workers:
idx_cursor = _fill_worker_indices(workers, indices, idx_cursor) w.start()
# Start all workers # Fetch results
for w in workers: for i in range(len(indices)):
w.start() # Fetch result and put index
try:
result = self.workers[i % self.num_worker].get()
except queue.Empty:
raise Exception("Generator worker process timeout")
except KeyboardInterrupt:
self.eof.set()
for w in self.workers:
w.terminate()
w.join()
raise Exception("Generator worker receives KeyboardInterrupt")
if idx_cursor < len(indices):
idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor)
# Set eoe event once all indices are sent
if idx_cursor == len(indices) and not self.eoe.is_set():
self.eoe.set()
yield tuple([np.array(x, copy=False) for x in result])
# Fetch results def __del__(self):
for i in range(len(indices)): self.eoe.set()
# Fetch result and put index self.eof.set()
try: if self.multi_process is False:
result = workers[i % num_worker].get() for w in self.workers:
except queue.Empty:
raise Exception("Generator worker process timeout")
except KeyboardInterrupt:
for w in workers:
w.terminate()
w.join() w.join()
raise Exception("Generator worker receives KeyboardInterrupt")
if idx_cursor < len(indices):
idx_cursor = _fill_worker_indices(workers, indices, idx_cursor)
# Set eoe event once all indices are sent
if idx_cursor == len(indices) and not eoe.is_set():
eoe.set()
yield tuple([np.array(x, copy=False) for x in result])
def _generator_worker_loop(dataset, idx_queue, result_queue, eoe): def _generator_worker_loop(dataset, idx_queue, result_queue, eoe, eof):
""" """
Multiprocessing generator worker process loop. Multiprocessing or multithread generator worker process loop.
""" """
while True: while True:
# Fetch index, block # Fetch index, block
try: try:
idx = idx_queue.get() idx = idx_queue.get(timeout=10)
except KeyboardInterrupt: except KeyboardInterrupt:
raise Exception("Generator worker receives KeyboardInterrupt") raise Exception("Generator worker receives KeyboardInterrupt")
except queue.Empty:
if eof.is_set() or eoe.is_set():
raise Exception("Generator worker receives queue.Empty")
continue
if idx is None: if idx is None:
# When the queue is out of scope from master process, a None item can be fetched from the queue. # When the queue is out of scope from master process, a None item can be fetched from the queue.
# Upon receiving None, worker process should check if EOE is set. # Upon receiving None, worker process should check if EOE is set.
assert eoe.is_set(), "" assert eoe.is_set(), ""
return return
if eof.is_set():
return
# Fetch data, any exception from __getitem__ will terminate worker and timeout master process # Fetch data, any exception from __getitem__ will terminate worker and timeout master process
result = dataset[idx] result = dataset[idx]
# Send data, block # Send data, block
@ -3197,17 +3228,19 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eoe):
except KeyboardInterrupt: except KeyboardInterrupt:
raise Exception("Generator worker receives KeyboardInterrupt") raise Exception("Generator worker receives KeyboardInterrupt")
del result, idx del result, idx
if eoe.is_set() and idx_queue.empty():
return
class _GeneratorWorker(multiprocessing.Process): class _GeneratorWorkerMt(threading.Thread):
""" """
Worker process for multiprocess Generator. Worker process for multithread Generator.
""" """
def __init__(self, dataset, eoe): def __init__(self, dataset, eoe, eof):
self.idx_queue = multiprocessing.Queue(16) self.idx_queue = queue.Queue(16)
self.res_queue = multiprocessing.Queue(16) self.res_queue = queue.Queue(16)
super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe)) super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe, eof))
def put(self, item): def put(self, item):
""" """
@ -3219,7 +3252,30 @@ class _GeneratorWorker(multiprocessing.Process):
""" """
Get function for worker result queue. Block with timeout. Get function for worker result queue. Block with timeout.
""" """
return self.res_queue.get() return self.res_queue.get(timeout=10)
class _GeneratorWorkerMp(multiprocessing.Process):
    """
    Generator worker that runs the worker loop in a separate process.
    """

    def __init__(self, dataset, eoe, eof):
        # One bounded queue for incoming indices, one for outgoing results.
        self.idx_queue = multiprocessing.Queue(16)
        self.res_queue = multiprocessing.Queue(16)
        loop_args = (dataset, self.idx_queue, self.res_queue, eoe, eof)
        super().__init__(target=_generator_worker_loop, args=loop_args)

    def put(self, item):
        """
        Enqueue an index for this worker without blocking; raises queue.Full when the index queue is full.
        """
        self.idx_queue.put_nowait(item)

    def get(self):
        """
        Dequeue one result from this worker, waiting at most 10 seconds.
        """
        return self.res_queue.get(timeout=10)

    def __del__(self):
        self.terminate()
@ -3282,6 +3338,8 @@ class GeneratorDataset(MappableDataset):
When this argument is specified, 'num_samples' will not effect. Random accessible input is required. When this argument is specified, 'num_samples' will not effect. Random accessible input is required.
shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only
when num_shards is also specified. Random accessible input is required. when num_shards is also specified. Random accessible input is required.
python_multiprocessing (bool, optional): Parallelize python operations with multiple worker processes. This
option could be beneficial if the python operation is computationally heavy (default=True).
Examples: Examples:
>>> import mindspore.dataset as ds >>> import mindspore.dataset as ds
@ -3318,12 +3376,14 @@ class GeneratorDataset(MappableDataset):
@check_generatordataset @check_generatordataset
def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None, def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None,
num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None): num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None,
python_multiprocessing=True):
super().__init__(num_parallel_workers) super().__init__(num_parallel_workers)
self.source = source self.source = source
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
self.num_samples = num_samples self.num_samples = num_samples
self.num_shards = num_shards self.num_shards = num_shards
self.python_multiprocessing = python_multiprocessing
if column_names is not None and not isinstance(column_names, list): if column_names is not None and not isinstance(column_names, list):
column_names = [column_names] column_names = [column_names]
@ -3405,12 +3465,16 @@ class GeneratorDataset(MappableDataset):
sampler_instance.set_num_rows(len(self.source)) sampler_instance.set_num_rows(len(self.source))
sampler_instance.initialize() sampler_instance.initialize()
if new_op.num_parallel_workers > 1: if new_op.num_parallel_workers > 1:
new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, self.source, new_op.num_parallel_workers)) new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, self.source,
new_op.num_parallel_workers,
self.python_multiprocessing))
else: else:
new_op.source = (lambda: _cpp_sampler_fn(sampler_instance, self.source)) new_op.source = (lambda: _cpp_sampler_fn(sampler_instance, self.source))
else: else:
if new_op.num_parallel_workers > 1: if new_op.num_parallel_workers > 1:
new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, self.source, new_op.num_parallel_workers)) new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, self.source,
new_op.num_parallel_workers,
self.python_multiprocessing))
else: else:
new_op.source = (lambda: _py_sampler_fn(new_op.sampler, new_op.num_samples, self.source)) new_op.source = (lambda: _py_sampler_fn(new_op.sampler, new_op.num_samples, self.source))
else: else:

View File

@ -16,8 +16,11 @@
graphdata.py supports loading graph dataset for GNN network training, graphdata.py supports loading graph dataset for GNN network training,
and provides operations related to graph data. and provides operations related to graph data.
""" """
import atexit
import time
import numpy as np import numpy as np
from mindspore._c_dataengine import Graph from mindspore._c_dataengine import GraphDataClient
from mindspore._c_dataengine import GraphDataServer
from mindspore._c_dataengine import Tensor from mindspore._c_dataengine import Tensor
from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \ from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \
@ -34,14 +37,52 @@ class GraphData:
dataset_file (str): One of file names in dataset. dataset_file (str): One of file names in dataset.
num_parallel_workers (int, optional): Number of workers to process the Dataset in parallel num_parallel_workers (int, optional): Number of workers to process the Dataset in parallel
(default=None). (default=None).
working_mode (str, optional): Set working mode, now support 'local'/'client'/'server' (default='local').
- 'local', used in non-distributed training scenarios.
- 'client', used in distributed training scenarios, the client does not load data,
but obtains data from the server.
- 'server', used in distributed training scenarios, the server loads the data
and is available to the client.
hostname (str, optional): Valid when working_mode is set to 'client' or 'server',
set the hostname of the graph data server (default='127.0.0.1').
port (int, optional): Valid when working_mode is set to 'client' or 'server',
set the port of the graph data server, the range is 1024-65535 (default=50051).
num_client (int, optional): Valid when working_mode is set to 'server',
set the number of clients expected to connect, and the server will allocate corresponding
resources according to this parameter (default=1).
auto_shutdown (bool, optional): Valid when working_mode is set to 'server'.
Controls whether the server exits automatically once all clients have connected
and no client remains connected to the server (default=True).
""" """
@check_gnn_graphdata @check_gnn_graphdata
def __init__(self, dataset_file, num_parallel_workers=None): def __init__(self, dataset_file, num_parallel_workers=None, working_mode='local', hostname='127.0.0.1', port=50051,
num_client=1, auto_shutdown=True):
self._dataset_file = dataset_file self._dataset_file = dataset_file
self._working_mode = working_mode
if num_parallel_workers is None: if num_parallel_workers is None:
num_parallel_workers = 1 num_parallel_workers = 1
self._graph = Graph(dataset_file, num_parallel_workers)
def stop():
self._graph_data.stop()
atexit.register(stop)
if working_mode in ['local', 'client']:
self._graph_data = GraphDataClient(dataset_file, num_parallel_workers, working_mode, hostname, port)
if working_mode == 'server':
self._graph_data = GraphDataServer(
dataset_file, num_parallel_workers, hostname, port, num_client, auto_shutdown)
try:
while self._graph_data.is_stoped() is not True:
time.sleep(1)
except KeyboardInterrupt:
# self._graph_data.stop()
raise Exception("Graph data server receives KeyboardInterrupt")
@check_gnn_get_all_nodes @check_gnn_get_all_nodes
def get_all_nodes(self, node_type): def get_all_nodes(self, node_type):
@ -62,7 +103,9 @@ class GraphData:
Raises: Raises:
TypeError: If `node_type` is not integer. TypeError: If `node_type` is not integer.
""" """
return self._graph.get_all_nodes(node_type).as_array() if self._working_mode == 'server':
raise Exception("This method is not supported when working mode is server")
return self._graph_data.get_all_nodes(node_type).as_array()
@check_gnn_get_all_edges @check_gnn_get_all_edges
def get_all_edges(self, edge_type): def get_all_edges(self, edge_type):
@ -83,7 +126,9 @@ class GraphData:
Raises: Raises:
TypeError: If `edge_type` is not integer. TypeError: If `edge_type` is not integer.
""" """
return self._graph.get_all_edges(edge_type).as_array() if self._working_mode == 'server':
raise Exception("This method is not supported when working mode is server")
return self._graph_data.get_all_edges(edge_type).as_array()
@check_gnn_get_nodes_from_edges @check_gnn_get_nodes_from_edges
def get_nodes_from_edges(self, edge_list): def get_nodes_from_edges(self, edge_list):
@ -99,7 +144,9 @@ class GraphData:
Raises: Raises:
TypeError: If `edge_list` is not list or ndarray. TypeError: If `edge_list` is not list or ndarray.
""" """
return self._graph.get_nodes_from_edges(edge_list).as_array() if self._working_mode == 'server':
raise Exception("This method is not supported when working mode is server")
return self._graph_data.get_nodes_from_edges(edge_list).as_array()
@check_gnn_get_all_neighbors @check_gnn_get_all_neighbors
def get_all_neighbors(self, node_list, neighbor_type): def get_all_neighbors(self, node_list, neighbor_type):
@ -123,7 +170,9 @@ class GraphData:
TypeError: If `node_list` is not list or ndarray. TypeError: If `node_list` is not list or ndarray.
TypeError: If `neighbor_type` is not integer. TypeError: If `neighbor_type` is not integer.
""" """
return self._graph.get_all_neighbors(node_list, neighbor_type).as_array() if self._working_mode == 'server':
raise Exception("This method is not supported when working mode is server")
return self._graph_data.get_all_neighbors(node_list, neighbor_type).as_array()
@check_gnn_get_sampled_neighbors @check_gnn_get_sampled_neighbors
def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types): def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types):
@ -155,7 +204,9 @@ class GraphData:
TypeError: If `neighbor_nums` is not list or ndarray. TypeError: If `neighbor_nums` is not list or ndarray.
TypeError: If `neighbor_types` is not list or ndarray. TypeError: If `neighbor_types` is not list or ndarray.
""" """
return self._graph.get_sampled_neighbors( if self._working_mode == 'server':
raise Exception("This method is not supported when working mode is server")
return self._graph_data.get_sampled_neighbors(
node_list, neighbor_nums, neighbor_types).as_array() node_list, neighbor_nums, neighbor_types).as_array()
@check_gnn_get_neg_sampled_neighbors @check_gnn_get_neg_sampled_neighbors
@ -182,7 +233,9 @@ class GraphData:
TypeError: If `neg_neighbor_num` is not integer. TypeError: If `neg_neighbor_num` is not integer.
TypeError: If `neg_neighbor_type` is not integer. TypeError: If `neg_neighbor_type` is not integer.
""" """
return self._graph.get_neg_sampled_neighbors( if self._working_mode == 'server':
raise Exception("This method is not supported when working mode is server")
return self._graph_data.get_neg_sampled_neighbors(
node_list, neg_neighbor_num, neg_neighbor_type).as_array() node_list, neg_neighbor_num, neg_neighbor_type).as_array()
@check_gnn_get_node_feature @check_gnn_get_node_feature
@ -207,10 +260,12 @@ class GraphData:
TypeError: If `node_list` is not list or ndarray. TypeError: If `node_list` is not list or ndarray.
TypeError: If `feature_types` is not list or ndarray. TypeError: If `feature_types` is not list or ndarray.
""" """
if self._working_mode == 'server':
raise Exception("This method is not supported when working mode is server")
if isinstance(node_list, list): if isinstance(node_list, list):
node_list = np.array(node_list, dtype=np.int32) node_list = np.array(node_list, dtype=np.int32)
return [ return [
t.as_array() for t in self._graph.get_node_feature( t.as_array() for t in self._graph_data.get_node_feature(
Tensor(node_list), Tensor(node_list),
feature_types)] feature_types)]
@ -236,10 +291,12 @@ class GraphData:
TypeError: If `edge_list` is not list or ndarray. TypeError: If `edge_list` is not list or ndarray.
TypeError: If `feature_types` is not list or ndarray. TypeError: If `feature_types` is not list or ndarray.
""" """
if self._working_mode == 'server':
raise Exception("This method is not supported when working mode is server")
if isinstance(edge_list, list): if isinstance(edge_list, list):
edge_list = np.array(edge_list, dtype=np.int32) edge_list = np.array(edge_list, dtype=np.int32)
return [ return [
t.as_array() for t in self._graph.get_edge_feature( t.as_array() for t in self._graph_data.get_edge_feature(
Tensor(edge_list), Tensor(edge_list),
feature_types)] feature_types)]
@ -252,7 +309,9 @@ class GraphData:
dict: Meta information of the graph. The key is node_type, edge_type, node_num, edge_num, dict: Meta information of the graph. The key is node_type, edge_type, node_num, edge_num,
node_feature_type and edge_feature_type. node_feature_type and edge_feature_type.
""" """
return self._graph.graph_info() if self._working_mode == 'server':
raise Exception("This method is not supported when working mode is server")
return self._graph_data.graph_info()
@check_gnn_random_walk @check_gnn_random_walk
def random_walk( def random_walk(
@ -285,5 +344,7 @@ class GraphData:
TypeError: If `target_nodes` is not list or ndarray. TypeError: If `target_nodes` is not list or ndarray.
TypeError: If `meta_path` is not list or ndarray. TypeError: If `meta_path` is not list or ndarray.
""" """
return self._graph.random_walk(target_nodes, meta_path, step_home_param, step_away_param, if self._working_mode == 'server':
default_node).as_array() raise Exception("This method is not supported when working mode is server")
return self._graph_data.random_walk(target_nodes, meta_path, step_home_param, step_away_param,
default_node).as_array()

View File

@ -18,6 +18,7 @@ Built-in validators.
""" """
import inspect as ins import inspect as ins
import os import os
import re
from functools import wraps from functools import wraps
import numpy as np import numpy as np
@ -912,16 +913,36 @@ def check_split(method):
return new_method return new_method
def check_hostname(hostname):
    """
    Check whether a string is a syntactically valid hostname.

    The name is limited to 255 characters. One trailing dot is tolerated and
    removed. Every dot-separated label must then be 1 to 63 characters long,
    consist only of letters, digits and '-', and neither start nor end with '-'.

    Args:
        hostname (str): The hostname to validate.

    Returns:
        bool, True if the hostname is valid, False otherwise.
    """
    # An empty string has no labels; reject it here so that the trailing-dot
    # probe below cannot raise IndexError on hostname[-1].
    if not hostname or len(hostname) > 255:
        return False
    if hostname[-1] == ".":
        hostname = hostname[:-1]  # strip exactly one dot from the right, if present
    # match() anchors at the start of the label and '$' at its end, so the
    # pattern has to cover the whole label.
    allowed = re.compile("(?!-)[A-Z\\d-]{1,63}(?<!-)$", re.IGNORECASE)
    return all(allowed.match(x) for x in hostname.split("."))
def check_gnn_graphdata(method): def check_gnn_graphdata(method):
"""check the input arguments of graphdata.""" """check the input arguments of graphdata."""
@wraps(method) @wraps(method)
def new_method(self, *args, **kwargs): def new_method(self, *args, **kwargs):
[dataset_file, num_parallel_workers], _ = parse_user_args(method, *args, **kwargs) [dataset_file, num_parallel_workers, working_mode, hostname,
port, num_client, auto_shutdown], _ = parse_user_args(method, *args, **kwargs)
check_file(dataset_file) check_file(dataset_file)
if num_parallel_workers is not None: if num_parallel_workers is not None:
check_num_parallel_workers(num_parallel_workers) check_num_parallel_workers(num_parallel_workers)
type_check(hostname, (str,), "hostname")
if check_hostname(hostname) is False:
raise ValueError("The hostname is illegal")
type_check(working_mode, (str,), "working_mode")
if working_mode not in {'local', 'client', 'server'}:
raise ValueError("Invalid working mode")
type_check(port, (int,), "port")
check_value(port, (1024, 65535), "port")
type_check(num_client, (int,), "num_client")
check_value(num_client, (1, 255), "num_client")
type_check(auto_shutdown, (bool,), "auto_shutdown")
return method(self, *args, **kwargs) return method(self, *args, **kwargs)
return new_method return new_method

View File

@ -15,6 +15,7 @@
""" """
User-defined API for MindRecord GNN writer. User-defined API for MindRecord GNN writer.
""" """
import numpy as np
social_data = [[348, 350], [348, 327], [348, 329], [348, 331], [348, 335], social_data = [[348, 350], [348, 327], [348, 329], [348, 331], [348, 335],
[348, 336], [348, 337], [348, 338], [348, 340], [348, 341], [348, 336], [348, 337], [348, 338], [348, 340], [348, 341],
[348, 342], [348, 343], [348, 344], [348, 345], [348, 346], [348, 342], [348, 343], [348, 344], [348, 345], [348, 346],
@ -29,7 +30,7 @@ social_data = [[348, 350], [348, 327], [348, 329], [348, 331], [348, 335],
[355, 352], [353, 350], [352, 349], [351, 349], [350, 349]] [355, 352], [353, 350], [352, 349], [351, 349], [350, 349]]
# profile: (num_features, feature_data_types, feature_shapes) # profile: (num_features, feature_data_types, feature_shapes)
node_profile = (0, [], []) node_profile = (2, ["int64", "int32"], [[-1], [-1]])
edge_profile = (0, [], []) edge_profile = (0, [], [])
@ -51,7 +52,9 @@ def yield_nodes(task_id=0):
node_list.sort() node_list.sort()
print(node_list) print(node_list)
for node_id in node_list: for node_id in node_list:
node = {'id': node_id, 'type': 1} node = {'id': node_id, 'type': 1,
'feature_1': np.ones((5,), dtype=np.int64),
'feature_2': np.ones((10,), dtype=np.int32)}
yield node yield node

View File

@ -22,6 +22,7 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "minddata/dataset/util/status.h" #include "minddata/dataset/util/status.h"
#include "minddata/dataset/engine/gnn/node.h" #include "minddata/dataset/engine/gnn/node.h"
#include "minddata/dataset/engine/gnn/graph_data_impl.h"
#include "minddata/dataset/engine/gnn/graph_loader.h" #include "minddata/dataset/engine/gnn/graph_loader.h"
using namespace mindspore::dataset; using namespace mindspore::dataset;
@ -39,30 +40,9 @@ class MindDataTestGNNGraph : public UT::Common {
MindDataTestGNNGraph() = default; MindDataTestGNNGraph() = default;
}; };
TEST_F(MindDataTestGNNGraph, TestGraphLoader) {
std::string path = "data/mindrecord/testGraphData/testdata";
GraphLoader gl(path, 4);
EXPECT_TRUE(gl.InitAndLoad().IsOk());
NodeIdMap n_id_map;
EdgeIdMap e_id_map;
NodeTypeMap n_type_map;
EdgeTypeMap e_type_map;
NodeFeatureMap n_feature_map;
EdgeFeatureMap e_feature_map;
DefaultNodeFeatureMap default_node_feature_map;
DefaultEdgeFeatureMap default_edge_feature_map;
EXPECT_TRUE(gl.GetNodesAndEdges(&n_id_map, &e_id_map, &n_type_map, &e_type_map, &n_feature_map, &e_feature_map,
&default_node_feature_map, &default_edge_feature_map)
.IsOk());
EXPECT_EQ(n_id_map.size(), 20);
EXPECT_EQ(e_id_map.size(), 40);
EXPECT_EQ(n_type_map[2].size(), 10);
EXPECT_EQ(n_type_map[1].size(), 10);
}
TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) { TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) {
std::string path = "data/mindrecord/testGraphData/testdata"; std::string path = "data/mindrecord/testGraphData/testdata";
Graph graph(path, 1); GraphDataImpl graph(path, 1);
Status s = graph.Init(); Status s = graph.Init();
EXPECT_TRUE(s.IsOk()); EXPECT_TRUE(s.IsOk());
@ -103,7 +83,7 @@ TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) {
TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) { TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) {
std::string path = "data/mindrecord/testGraphData/testdata"; std::string path = "data/mindrecord/testGraphData/testdata";
Graph graph(path, 1); GraphDataImpl graph(path, 1);
Status s = graph.Init(); Status s = graph.Init();
EXPECT_TRUE(s.IsOk()); EXPECT_TRUE(s.IsOk());
@ -194,7 +174,7 @@ TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) {
TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) { TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) {
std::string path = "data/mindrecord/testGraphData/testdata"; std::string path = "data/mindrecord/testGraphData/testdata";
Graph graph(path, 1); GraphDataImpl graph(path, 1);
Status s = graph.Init(); Status s = graph.Init();
EXPECT_TRUE(s.IsOk()); EXPECT_TRUE(s.IsOk());
@ -237,7 +217,7 @@ TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) {
TEST_F(MindDataTestGNNGraph, TestRandomWalk) { TEST_F(MindDataTestGNNGraph, TestRandomWalk) {
std::string path = "data/mindrecord/testGraphData/sns"; std::string path = "data/mindrecord/testGraphData/sns";
Graph graph(path, 1); GraphDataImpl graph(path, 1);
Status s = graph.Init(); Status s = graph.Init();
EXPECT_TRUE(s.IsOk()); EXPECT_TRUE(s.IsOk());
@ -263,7 +243,7 @@ TEST_F(MindDataTestGNNGraph, TestRandomWalk) {
TEST_F(MindDataTestGNNGraph, TestRandomWalkDefaults) { TEST_F(MindDataTestGNNGraph, TestRandomWalkDefaults) {
std::string path = "data/mindrecord/testGraphData/sns"; std::string path = "data/mindrecord/testGraphData/sns";
Graph graph(path, 1); GraphDataImpl graph(path, 1);
Status s = graph.Init(); Status s = graph.Init();
EXPECT_TRUE(s.IsOk()); EXPECT_TRUE(s.IsOk());

View File

@ -0,0 +1,125 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import random
import time
from multiprocessing import Process
import numpy as np
import mindspore.dataset as ds
from mindspore import log as logger
DATASET_FILE = "../data/mindrecord/testGraphData/testdata"
def graphdata_startserver():
    """
    Create a GraphData object in 'server' working mode on the test dataset.

    Used as the target of the server process spawned by the distributed test.
    """
    logger.info('test start server.\n')
    ds.GraphData(dataset_file=DATASET_FILE, num_parallel_workers=1, working_mode='server')
class RandomBatchedSampler(ds.Sampler):
    """
    Sampler that yields batches of randomly ordered indices without replacement.

    The 1-based indices 1..index_range are shuffled once per iteration and then
    emitted in consecutive slices of num_edges_per_sample elements. A trailing
    slice that would be shorter than num_edges_per_sample is dropped.
    """

    def __init__(self, index_range, num_edges_per_sample):
        super().__init__()
        self.index_range = index_range
        self.num_edges_per_sample = num_edges_per_sample

    def __iter__(self):
        shuffled = list(range(1, self.index_range + 1))
        # Seed the random module before this point if a reproducible order is needed.
        random.shuffle(shuffled)
        batch_size = self.num_edges_per_sample
        # Drop remainder: the last start offset still leaves room for a full batch.
        for start in range(0, self.index_range - batch_size + 1, batch_size):
            yield shuffled[start:start + batch_size]
class GNNGraphDataset():
    """
    Random-access source that turns a batch of edge ids into sampled neighbors
    and their features by querying the wrapped graph object.
    """

    def __init__(self, g, batch_num):
        self.g = g
        self.batch_num = batch_num

    def __len__(self):
        # Number of samples: total edge count (first 'edge_num' entry) divided
        # by the number of edges per sample.
        total_edges = self.g.graph_info()['edge_num'][0]
        return total_edges // self.batch_num

    def __getitem__(self, index):
        """
        Build one sample from the edge ids in `index`.

        Returns a 4-tuple: sampled neighbors of the edge nodes, sampled neighbors
        of the negative samples, the first feature of the former and the second
        feature of the latter.
        """
        graph = self.g
        # index is the batch of indices produced by the sampler; keep only the
        # first column of the node array returned for those edges.
        src_nodes = graph.get_nodes_from_edges(index.astype(np.int32))[:, 0]
        neg_samples = graph.get_neg_sampled_neighbors(
            node_list=src_nodes, neg_neighbor_num=3, neg_neighbor_type=1)
        pos_neighbors = graph.get_sampled_neighbors(
            node_list=src_nodes, neighbor_nums=[2, 2], neighbor_types=[2, 1])
        # Column 0 of neg_samples is skipped (presumably the source node itself
        # - verify against get_neg_sampled_neighbors); the rest is flattened.
        neg_neighbors = graph.get_sampled_neighbors(
            node_list=neg_samples[:, 1:].reshape(-1), neighbor_nums=[2, 2], neighbor_types=[2, 2])
        pos_features = graph.get_node_feature(
            node_list=pos_neighbors, feature_types=[2, 3])
        neg_features = graph.get_node_feature(
            node_list=neg_neighbors, feature_types=[2, 3])
        return pos_neighbors, neg_neighbors, pos_features[0], neg_features[1]
def test_graphdata_distributed():
    """
    Test distributed: start a GraphData server in a child process, connect a
    client to it, and check node/edge queries plus a GeneratorDataset pipeline
    driven by the client.
    """
    logger.info('test distributed.\n')

    server_process = Process(target=graphdata_startserver)
    server_process.start()
    # Fixed wait before the client is created, giving the server process time to start.
    # NOTE(review): the server process is never joined or terminated here - confirm
    # that it exits on its own once the client goes away.
    time.sleep(2)

    g = ds.GraphData(DATASET_FILE, 1, 'client')

    # Nodes of type 1 and their features.
    nodes = g.get_all_nodes(1)
    assert nodes.tolist() == list(range(101, 111))
    node_features = g.get_node_feature(nodes.tolist(), [1, 2, 3])
    expected_first_feature = [
        [0, 1, 0, 0, 0], [1, 0, 0, 0, 1], [0, 0, 1, 1, 0], [0, 0, 0, 0, 0],
        [1, 1, 0, 1, 0], [0, 0, 0, 0, 1], [0, 1, 0, 0, 0], [0, 0, 0, 1, 1],
        [0, 1, 1, 0, 0], [0, 1, 0, 1, 0]]
    assert node_features[0].tolist() == expected_first_feature
    assert node_features[2].tolist() == [1, 2, 3, 1, 4, 3, 5, 3, 5, 4]

    # Edges of type 0 and their features.
    edges = g.get_all_edges(0)
    assert edges.tolist() == list(range(1, 41))
    edge_features = g.get_edge_feature(edges, [1, 2])
    # The expected first edge feature repeats the same ten values four times.
    assert edge_features[0].tolist() == [0, 1, 0, 0, 1, 0, 0, 0, 0, 0] * 4

    # Feed the client graph through a dataset pipeline and check every batch.
    batch_num = 2
    edge_num = g.graph_info()['edge_num'][0]
    expected_shapes = {
        "neighbors": (2, 7),
        "neg_neighbors": (6, 7),
        "neighbors_features": (2, 7),
        "neg_neighbors_features": (6, 7),
    }
    dataset = ds.GeneratorDataset(source=GNNGraphDataset(g, batch_num),
                                  column_names=list(expected_shapes),
                                  sampler=RandomBatchedSampler(edge_num, batch_num),
                                  num_parallel_workers=4,
                                  python_multiprocessing=False)
    dataset = dataset.repeat(2)

    batches_seen = 0
    for batch in dataset.create_dict_iterator():
        for column, shape in expected_shapes.items():
            assert batch[column].shape == shape
        batches_seen += 1
    assert batches_seen == 40
if __name__ == '__main__':
test_graphdata_distributed()