From 8ee4d8e92d02ff255ab42d521075b974adb105a0 Mon Sep 17 00:00:00 2001 From: heleiwang Date: Fri, 14 Aug 2020 14:56:27 +0800 Subject: [PATCH] Gnn data processing supports distributed scenarios --- cmake/mind_expression.cmake | 7 + .../ccsrc/minddata/dataset/CMakeLists.txt | 5 + .../bindings/dataset/engine/gnn/bindings.cc | 56 +- .../dataset/engine/gnn/CMakeLists.txt | 26 +- .../minddata/dataset/engine/gnn/feature.cc | 3 +- .../minddata/dataset/engine/gnn/feature.h | 3 +- .../dataset/engine/gnn/gnn_graph_data.proto | 103 +++ .../dataset/engine/gnn/gnn_tensor.proto | 42 ++ .../minddata/dataset/engine/gnn/graph_data.h | 134 ++++ .../dataset/engine/gnn/graph_data_client.cc | 589 ++++++++++++++++++ .../dataset/engine/gnn/graph_data_client.h | 185 ++++++ .../gnn/{graph.cc => graph_data_impl.cc} | 189 ++++-- .../engine/gnn/{graph.h => graph_data_impl.h} | 94 +-- .../dataset/engine/gnn/graph_data_server.cc | 133 ++++ .../dataset/engine/gnn/graph_data_server.h | 196 ++++++ .../engine/gnn/graph_data_service_impl.cc | 299 +++++++++ .../engine/gnn/graph_data_service_impl.h | 70 +++ .../engine/gnn/graph_feature_parser.cc | 106 ++++ .../dataset/engine/gnn/graph_feature_parser.h | 67 ++ .../dataset/engine/gnn/graph_loader.cc | 188 +++--- .../dataset/engine/gnn/graph_loader.h | 38 +- .../dataset/engine/gnn/graph_shared_memory.cc | 134 ++++ .../dataset/engine/gnn/graph_shared_memory.h | 72 +++ .../dataset/engine/gnn/grpc_async_server.cc | 82 +++ .../dataset/engine/gnn/grpc_async_server.h | 59 ++ .../minddata/dataset/engine/gnn/local_edge.cc | 1 + .../minddata/dataset/engine/gnn/local_edge.h | 2 +- .../minddata/dataset/engine/gnn/local_node.h | 2 +- .../ccsrc/minddata/dataset/engine/gnn/node.h | 2 +- .../dataset/engine/gnn/tensor_proto.cc | 84 +++ .../dataset/engine/gnn/tensor_proto.h | 36 ++ .../mindrecord/include/shard_column.h | 12 +- .../mindrecord/include/shard_header.h | 10 + .../mindrecord/include/shard_reader.h | 4 + .../mindrecord/include/shard_writer.h | 1 + 
.../minddata/mindrecord/io/shard_reader.cc | 19 +- .../minddata/mindrecord/io/shard_writer.cc | 18 +- .../minddata/mindrecord/meta/shard_column.cc | 20 +- .../minddata/mindrecord/meta/shard_header.cc | 9 +- mindspore/dataset/engine/datasets.py | 164 +++-- mindspore/dataset/engine/graphdata.py | 89 ++- mindspore/dataset/engine/validators.py | 25 +- .../utils/graph_to_mindrecord/sns/mr_api.py | 7 +- tests/ut/cpp/dataset/gnn_graph_test.cc | 32 +- tests/ut/data/mindrecord/testGraphData/sns | Bin 58572 -> 58572 bytes tests/ut/data/mindrecord/testGraphData/sns.db | Bin 24576 -> 24576 bytes .../ut/data/mindrecord/testGraphData/testdata | Bin 52682 -> 52682 bytes .../dataset/test_graphdata_distributed.py | 125 ++++ 48 files changed, 3202 insertions(+), 340 deletions(-) create mode 100644 mindspore/ccsrc/minddata/dataset/engine/gnn/gnn_graph_data.proto create mode 100644 mindspore/ccsrc/minddata/dataset/engine/gnn/gnn_tensor.proto create mode 100644 mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data.h create mode 100644 mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.cc create mode 100644 mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.h rename mindspore/ccsrc/minddata/dataset/engine/gnn/{graph.cc => graph_data_impl.cc} (76%) rename mindspore/ccsrc/minddata/dataset/engine/gnn/{graph.h => graph_data_impl.h} (81%) create mode 100644 mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_server.cc create mode 100644 mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_server.h create mode 100644 mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_service_impl.cc create mode 100644 mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_service_impl.h create mode 100644 mindspore/ccsrc/minddata/dataset/engine/gnn/graph_feature_parser.cc create mode 100644 mindspore/ccsrc/minddata/dataset/engine/gnn/graph_feature_parser.h create mode 100644 mindspore/ccsrc/minddata/dataset/engine/gnn/graph_shared_memory.cc create mode 100644 
mindspore/ccsrc/minddata/dataset/engine/gnn/graph_shared_memory.h create mode 100644 mindspore/ccsrc/minddata/dataset/engine/gnn/grpc_async_server.cc create mode 100644 mindspore/ccsrc/minddata/dataset/engine/gnn/grpc_async_server.h create mode 100644 mindspore/ccsrc/minddata/dataset/engine/gnn/tensor_proto.cc create mode 100644 mindspore/ccsrc/minddata/dataset/engine/gnn/tensor_proto.h create mode 100644 tests/ut/python/dataset/test_graphdata_distributed.py diff --git a/cmake/mind_expression.cmake b/cmake/mind_expression.cmake index 8e1e9ce553c..f6fca0891eb 100644 --- a/cmake/mind_expression.cmake +++ b/cmake/mind_expression.cmake @@ -15,7 +15,14 @@ include(${CMAKE_SOURCE_DIR}/cmake/external_libs/json.cmake) include(${CMAKE_SOURCE_DIR}/cmake/dependency_securec.cmake) include(${CMAKE_SOURCE_DIR}/cmake/external_libs/protobuf.cmake) +SET(MS_BUILD_GRPC 0) if (ENABLE_DEBUGGER OR ENABLE_SERVING OR ENABLE_TESTCASES) + SET(MS_BUILD_GRPC 1) +endif() +if (ENABLE_MINDDATA AND NOT CMAKE_SYSTEM_NAME MATCHES "Windows") + SET(MS_BUILD_GRPC 1) +endif() +if ("${MS_BUILD_GRPC}") # build dependencies of gRPC include(${CMAKE_SOURCE_DIR}/cmake/external_libs/absl.cmake) include(${CMAKE_SOURCE_DIR}/cmake/external_libs/c-ares.cmake) diff --git a/mindspore/ccsrc/minddata/dataset/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/CMakeLists.txt index 7fa5a69038c..678c20ba102 100644 --- a/mindspore/ccsrc/minddata/dataset/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/CMakeLists.txt @@ -83,6 +83,7 @@ endif() if (ENABLE_TDTQUE) add_dependencies(engine-tdt core) endif () + ################### Create _c_dataengine Library ###################### set(submodules $ @@ -182,3 +183,7 @@ else() set_target_properties(_c_dataengine PROPERTIES MACOSX_RPATH ON) endif () endif() + +if (NOT CMAKE_SYSTEM_NAME MATCHES "Windows") + target_link_libraries(_c_dataengine PRIVATE mindspore::grpc++) +endif() \ No newline at end of file diff --git 
a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/gnn/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/gnn/bindings.cc index 18dcfb470a0..936ba2804e9 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/gnn/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/gnn/bindings.cc @@ -18,83 +18,103 @@ #include "pybind11/stl_bind.h" #include "minddata/dataset/api/python/pybind_register.h" - -#include "minddata/dataset/engine/gnn/graph.h" +#include "minddata/dataset/engine/gnn/graph_data_client.h" +#include "minddata/dataset/engine/gnn/graph_data_impl.h" +#include "minddata/dataset/engine/gnn/graph_data_server.h" namespace mindspore { namespace dataset { PYBIND_REGISTER( Graph, 0, ([](const py::module *m) { - (void)py::class_>(*m, "Graph") - .def(py::init([](std::string dataset_file, int32_t num_workers) { - std::shared_ptr g_out = std::make_shared(dataset_file, num_workers); - THROW_IF_ERROR(g_out->Init()); - return g_out; + (void)py::class_>(*m, "GraphDataClient") + .def(py::init([](const std::string &dataset_file, int32_t num_workers, const std::string &working_mode, + const std::string &hostname, int32_t port) { + std::shared_ptr out; + if (working_mode == "local") { + out = std::make_shared(dataset_file, num_workers); + } else if (working_mode == "client") { + out = std::make_shared(dataset_file, hostname, port); + } + THROW_IF_ERROR(out->Init()); + return out; })) .def("get_all_nodes", - [](gnn::Graph &g, gnn::NodeType node_type) { + [](gnn::GraphData &g, gnn::NodeType node_type) { std::shared_ptr out; THROW_IF_ERROR(g.GetAllNodes(node_type, &out)); return out; }) .def("get_all_edges", - [](gnn::Graph &g, gnn::EdgeType edge_type) { + [](gnn::GraphData &g, gnn::EdgeType edge_type) { std::shared_ptr out; THROW_IF_ERROR(g.GetAllEdges(edge_type, &out)); return out; }) .def("get_nodes_from_edges", - [](gnn::Graph &g, std::vector edge_list) { + 
[](gnn::GraphData &g, std::vector edge_list) { std::shared_ptr out; THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out)); return out; }) .def("get_all_neighbors", - [](gnn::Graph &g, std::vector node_list, gnn::NodeType neighbor_type) { + [](gnn::GraphData &g, std::vector node_list, gnn::NodeType neighbor_type) { std::shared_ptr out; THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out)); return out; }) .def("get_sampled_neighbors", - [](gnn::Graph &g, std::vector node_list, std::vector neighbor_nums, + [](gnn::GraphData &g, std::vector node_list, std::vector neighbor_nums, std::vector neighbor_types) { std::shared_ptr out; THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &out)); return out; }) .def("get_neg_sampled_neighbors", - [](gnn::Graph &g, std::vector node_list, gnn::NodeIdType neighbor_num, + [](gnn::GraphData &g, std::vector node_list, gnn::NodeIdType neighbor_num, gnn::NodeType neg_neighbor_type) { std::shared_ptr out; THROW_IF_ERROR(g.GetNegSampledNeighbors(node_list, neighbor_num, neg_neighbor_type, &out)); return out; }) .def("get_node_feature", - [](gnn::Graph &g, std::shared_ptr node_list, std::vector feature_types) { + [](gnn::GraphData &g, std::shared_ptr node_list, std::vector feature_types) { TensorRow out; THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out)); return out.getRow(); }) .def("get_edge_feature", - [](gnn::Graph &g, std::shared_ptr edge_list, std::vector feature_types) { + [](gnn::GraphData &g, std::shared_ptr edge_list, std::vector feature_types) { TensorRow out; THROW_IF_ERROR(g.GetEdgeFeature(edge_list, feature_types, &out)); return out.getRow(); }) .def("graph_info", - [](gnn::Graph &g) { + [](gnn::GraphData &g) { py::dict out; THROW_IF_ERROR(g.GraphInfo(&out)); return out; }) .def("random_walk", - [](gnn::Graph &g, std::vector node_list, std::vector meta_path, + [](gnn::GraphData &g, std::vector node_list, std::vector meta_path, float step_home_param, float step_away_param, 
gnn::NodeIdType default_node) { std::shared_ptr out; THROW_IF_ERROR(g.RandomWalk(node_list, meta_path, step_home_param, step_away_param, default_node, &out)); return out; - }); + }) + .def("stop", [](gnn::GraphData &g) { THROW_IF_ERROR(g.Stop()); }); + + (void)py::class_>(*m, "GraphDataServer") + .def(py::init([](const std::string &dataset_file, int32_t num_workers, const std::string &hostname, int32_t port, + int32_t client_num, bool auto_shutdown) { + std::shared_ptr out; + out = + std::make_shared(dataset_file, num_workers, hostname, port, client_num, auto_shutdown); + THROW_IF_ERROR(out->Init()); + return out; + })) + .def("stop", [](gnn::GraphDataServer &g) { THROW_IF_ERROR(g.Stop()); }) + .def("is_stoped", [](gnn::GraphDataServer &g) { return g.IsStoped(); }); })); } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/engine/gnn/CMakeLists.txt index 401fce6d118..52f0707310b 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/CMakeLists.txt @@ -1,9 +1,29 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) -add_library(engine-gnn OBJECT - graph.cc +set(DATASET_ENGINE_GNN_SRC_FILES + graph_data_impl.cc + graph_data_client.cc + graph_data_server.cc graph_loader.cc + graph_feature_parser.cc local_node.cc local_edge.cc feature.cc - ) +) + +if (CMAKE_SYSTEM_NAME MATCHES "Windows") + add_library(engine-gnn OBJECT ${DATASET_ENGINE_GNN_SRC_FILES}) +else() + set(DATASET_ENGINE_GNN_SRC_FILES + ${DATASET_ENGINE_GNN_SRC_FILES} + tensor_proto.cc + grpc_async_server.cc + graph_data_service_impl.cc + graph_shared_memory.cc) + + ms_protobuf_generate(TENSOR_PROTO_SRCS TENSOR_PROTO_HDRS "gnn_tensor.proto") + ms_grpc_generate(GNN_PROTO_SRCS GNN_PROTO_HDRS "gnn_graph_data.proto") + + 
add_library(engine-gnn OBJECT ${DATASET_ENGINE_GNN_SRC_FILES} ${TENSOR_PROTO_SRCS} ${GNN_PROTO_SRCS}) + add_dependencies(engine-gnn mindspore::protobuf) +endif() diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/feature.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/feature.cc index dba4a6fa609..073415242b7 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/feature.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/feature.cc @@ -19,7 +19,8 @@ namespace mindspore { namespace dataset { namespace gnn { -Feature::Feature(FeatureType type_name, std::shared_ptr value) : type_name_(type_name), value_(value) {} +Feature::Feature(FeatureType type_name, std::shared_ptr value, bool is_shared_memory) + : type_name_(type_name), value_(value), is_shared_memory_(is_shared_memory) {} } // namespace gnn } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/feature.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/feature.h index 0151ada706b..aae1716f1af 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/feature.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/feature.h @@ -31,7 +31,7 @@ class Feature { // Constructor // @param FeatureType type_name - feature type // @param std::shared_ptr value - feature value - Feature(FeatureType type_name, std::shared_ptr value); + Feature(FeatureType type_name, std::shared_ptr value, bool is_shared_memory = false); ~Feature() = default; @@ -45,6 +45,7 @@ class Feature { private: FeatureType type_name_; std::shared_ptr value_; + bool is_shared_memory_; }; } // namespace gnn } // namespace dataset diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/gnn_graph_data.proto b/mindspore/ccsrc/minddata/dataset/engine/gnn/gnn_graph_data.proto new file mode 100644 index 00000000000..1342d047cc5 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/gnn_graph_data.proto @@ -0,0 +1,103 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 
2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +package mindspore.dataset; + +import "gnn_tensor.proto"; + +message GnnClientRegisterRequestPb { + int32 pid = 1; +} + +message GnnFeatureInfoPb { + int32 type = 1; + TensorPb feature = 2; +} + +message GnnClientRegisterResponsePb { + string error_msg = 1; + string data_schema = 2; + int64 shared_memory_key = 3; + int64 shared_memory_size = 4; + repeated GnnFeatureInfoPb default_node_feature = 5; + repeated GnnFeatureInfoPb default_edge_feature = 6; +} + +message GnnClientUnRegisterRequestPb { + int32 pid = 1; +} + +message GnnClientUnRegisterResponsePb { + string error_msg = 1; +} + +enum GnnOpName { + GET_ALL_NODES = 0; + GET_ALL_EDGES = 1; + GET_NODES_FROM_EDGES = 2; + GET_ALL_NEIGHBORS = 3; + GET_SAMPLED_NEIGHBORS = 4; + GET_NEG_SAMPLED_NEIGHBORS = 5; + RANDOM_WALK = 6; + GET_NODE_FEATURE = 7; + GET_EDGE_FEATURE = 8; +} + +message GnnRandomWalkPb { + float p = 1; + float q = 2; + int32 default_id = 3; +} + +message GnnGraphDataRequestPb { + GnnOpName op_name = 1; + repeated int32 id = 2; // node id or edge id + repeated int32 type = 3; //node type or edge type or neighbor type or feature type + repeated int32 number = 4; // samples number + TensorPb id_tensor = 5; // input ids ,node id or edge id + GnnRandomWalkPb random_walk = 6; +} + +message GnnGraphDataResponsePb { + string error_msg = 1; + repeated TensorPb result_data = 2; +} + +message GnnMetaInfoRequestPb { + +} + +message GnnNodeEdgeInfoPb { + int32 type = 1; + int32 
num = 2; +} + +message GnnMetaInfoResponsePb { + string error_msg = 1; + repeated GnnNodeEdgeInfoPb node_info = 2; + repeated GnnNodeEdgeInfoPb edge_info = 3; + repeated int32 node_feature_type = 4; + repeated int32 edge_feature_type = 5; +} + +service GnnGraphData { + rpc ClientRegister(GnnClientRegisterRequestPb) returns (GnnClientRegisterResponsePb); + rpc ClientUnRegister(GnnClientUnRegisterRequestPb) returns (GnnClientUnRegisterResponsePb); + rpc GetGraphData(GnnGraphDataRequestPb) returns (GnnGraphDataResponsePb); + rpc GetMetaInfo(GnnMetaInfoRequestPb) returns (GnnMetaInfoResponsePb); +} diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/gnn_tensor.proto b/mindspore/ccsrc/minddata/dataset/engine/gnn/gnn_tensor.proto new file mode 100644 index 00000000000..2dfa6438ad2 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/gnn_tensor.proto @@ -0,0 +1,42 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +syntax = "proto3"; + +package mindspore.dataset; + +enum DataTypePb { + DE_PB_UNKNOWN = 0; + DE_PB_BOOL = 1; + DE_PB_INT8 = 2; + DE_PB_UINT8 = 3; + DE_PB_INT16 = 4; + DE_PB_UINT16 = 5; + DE_PB_INT32 = 6; + DE_PB_UINT32 = 7; + DE_PB_INT64 = 8; + DE_PB_UINT64 = 9; + DE_PB_FLOAT16 = 10; + DE_PB_FLOAT32 = 11; + DE_PB_FLOAT64 = 12; + DE_PB_STRING = 13; +} + +message TensorPb { + repeated int64 dims = 1; // tensor shape info + DataTypePb tensor_type = 2; // tensor content data type + bytes data = 3; // tensor data +} diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data.h new file mode 100644 index 00000000000..5e3dc1d4058 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data.h @@ -0,0 +1,134 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_H_ + +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/core/tensor_row.h" +#include "minddata/dataset/engine/gnn/feature.h" +#include "minddata/dataset/engine/gnn/node.h" +#include "minddata/dataset/engine/gnn/edge.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +namespace gnn { + +struct MetaInfo { + std::vector node_type; + std::vector edge_type; + std::map node_num; + std::map edge_num; + std::vector node_feature_type; + std::vector edge_feature_type; +}; + +class GraphData { + public: + // Get all nodes from the graph. + // @param NodeType node_type - type of node + // @param std::shared_ptr *out - Returned nodes id + // @return Status - The error code return + virtual Status GetAllNodes(NodeType node_type, std::shared_ptr *out) = 0; + + // Get all edges from the graph. + // @param NodeType edge_type - type of edge + // @param std::shared_ptr *out - Returned edge ids + // @return Status - The error code return + virtual Status GetAllEdges(EdgeType edge_type, std::shared_ptr *out) = 0; + + // Get the node id from the edge. + // @param std::vector edge_list - List of edges + // @param std::shared_ptr *out - Returned node ids + // @return Status - The error code return + virtual Status GetNodesFromEdges(const std::vector &edge_list, std::shared_ptr *out) = 0; + + // All neighbors of the acquisition node. + // @param std::vector node_list - List of nodes + // @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported + // @param std::shared_ptr *out - Returned neighbor's id. Because the number of neighbors at different nodes is + // different, the returned tensor is output according to the maximum number of neighbors. 
If the number of neighbors + // is not enough, fill in tensor as -1. + // @return Status - The error code return + virtual Status GetAllNeighbors(const std::vector &node_list, NodeType neighbor_type, + std::shared_ptr *out) = 0; + + // Get sampled neighbors. + // @param std::vector node_list - List of nodes + // @param std::vector neighbor_nums - Number of neighbors sampled per hop + // @param std::vector neighbor_types - Neighbor type sampled per hop + // @param std::shared_ptr *out - Returned neighbor's id. + // @return Status - The error code return + virtual Status GetSampledNeighbors(const std::vector &node_list, + const std::vector &neighbor_nums, + const std::vector &neighbor_types, std::shared_ptr *out) = 0; + + // Get negative sampled neighbors. + // @param std::vector node_list - List of nodes + // @param NodeIdType samples_num - Number of neighbors sampled + // @param NodeType neg_neighbor_type - The type of negative neighbor. + // @param std::shared_ptr *out - Returned negative neighbor's id. + // @return Status - The error code return + virtual Status GetNegSampledNeighbors(const std::vector &node_list, NodeIdType samples_num, + NodeType neg_neighbor_type, std::shared_ptr *out) = 0; + + // Node2vec random walk. 
+ // @param std::vector node_list - List of nodes + // @param std::vector meta_path - node type of each step + // @param float step_home_param - return hyper parameter in node2vec algorithm + // @param float step_away_param - inout hyper parameter in node2vec algorithm + // @param NodeIdType default_node - default node id + // @param std::shared_ptr *out - Returned nodes id in walk path + // @return Status - The error code return + virtual Status RandomWalk(const std::vector &node_list, const std::vector &meta_path, + float step_home_param, float step_away_param, NodeIdType default_node, + std::shared_ptr *out) = 0; + + // Get the feature of a node + // @param std::shared_ptr nodes - List of nodes + // @param std::vector feature_types - Types of features, An error will be reported if the feature type + // does not exist. + // @param TensorRow *out - Returned features + // @return Status - The error code return + virtual Status GetNodeFeature(const std::shared_ptr &nodes, const std::vector &feature_types, + TensorRow *out) = 0; + + // Get the feature of a edge + // @param std::shared_ptr edges - List of edges + // @param std::vector feature_types - Types of features, An error will be reported if the feature type + // does not exist. 
+ // @param Tensor *out - Returned features + // @return Status - The error code return + virtual Status GetEdgeFeature(const std::shared_ptr &edges, const std::vector &feature_types, + TensorRow *out) = 0; + + // Return meta information to python layer + virtual Status GraphInfo(py::dict *out) = 0; + + virtual Status Init() = 0; + + virtual Status Stop() = 0; +}; +} // namespace gnn +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.cc new file mode 100644 index 00000000000..6fdde154268 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.cc @@ -0,0 +1,589 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "minddata/dataset/engine/gnn/graph_data_client.h" + +#include +#include +#include + +#if !defined(_WIN32) && !defined(_WIN64) +#include "grpcpp/grpcpp.h" +#endif + +#include "minddata/dataset/core/data_type.h" +#if !defined(_WIN32) && !defined(_WIN64) +#include "minddata/dataset/engine/gnn/tensor_proto.h" +#endif + +namespace mindspore { +namespace dataset { +namespace gnn { + +GraphDataClient::GraphDataClient(const std::string &dataset_file, const std::string &hostname, int32_t port) + : dataset_file_(dataset_file), + host_(hostname), + port_(port), + pid_(0), +#if !defined(_WIN32) && !defined(_WIN64) + shared_memory_key_(-1), + shared_memory_size_(0), + graph_feature_parser_(nullptr), + graph_shared_memory_(nullptr), +#endif + registered_(false) { +} + +GraphDataClient::~GraphDataClient() { (void)Stop(); } + +Status GraphDataClient::Init() { +#if defined(_WIN32) || defined(_WIN64) + RETURN_STATUS_UNEXPECTED("Graph data client is not supported in Windows OS"); +#else + if (!registered_) { + std::string server_address; + server_address = host_ + ":" + std::to_string(port_); + MS_LOG(INFO) << "Graph data client starting. 
address:" << server_address; + pid_ = getpid(); + grpc::ChannelArguments args; + args.SetMaxReceiveMessageSize(-1); + std::shared_ptr channel = + grpc::CreateCustomChannel(server_address, grpc::InsecureChannelCredentials(), args); + stub_ = GnnGraphData::NewStub(channel); + Status status = RegisterToServer(); + while (status.ToString().find("Initializing") != std::string::npos) { + MS_LOG(INFO) << "Graph data server is initializing, please wait."; + std::this_thread::sleep_for(std::chrono::milliseconds(2000)); + status = RegisterToServer(); + } + RETURN_IF_NOT_OK(status); + MS_LOG(INFO) << "Graph data client successfully registered with server " << server_address; + } + RETURN_IF_NOT_OK(InitFeatureParser()); + return Status::OK(); +#endif +} + +Status GraphDataClient::Stop() { +#if !defined(_WIN32) && !defined(_WIN64) + if (registered_) { + UnRegisterToServer(); + } +#endif + return Status::OK(); +} + +Status GraphDataClient::GetAllNodes(NodeType node_type, std::shared_ptr *out) { +#if !defined(_WIN32) && !defined(_WIN64) + GnnGraphDataRequestPb request; + GnnGraphDataResponsePb response; + request.set_op_name(GET_ALL_NODES); + request.add_type(static_cast(node_type)); + RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out)); +#endif + return Status::OK(); +} + +Status GraphDataClient::GetAllEdges(EdgeType edge_type, std::shared_ptr *out) { +#if !defined(_WIN32) && !defined(_WIN64) + GnnGraphDataRequestPb request; + GnnGraphDataResponsePb response; + request.set_op_name(GET_ALL_EDGES); + request.add_type(static_cast(edge_type)); + RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out)); +#endif + return Status::OK(); +} + +Status GraphDataClient::GetNodesFromEdges(const std::vector &edge_list, std::shared_ptr *out) { +#if !defined(_WIN32) && !defined(_WIN64) + GnnGraphDataRequestPb request; + GnnGraphDataResponsePb response; + request.set_op_name(GET_NODES_FROM_EDGES); + for (const auto &edge_id : edge_list) { + request.add_id(static_cast(edge_id)); + 
} + RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out)); +#endif + return Status::OK(); +} + +Status GraphDataClient::GetAllNeighbors(const std::vector &node_list, NodeType neighbor_type, + std::shared_ptr *out) { +#if !defined(_WIN32) && !defined(_WIN64) + GnnGraphDataRequestPb request; + GnnGraphDataResponsePb response; + request.set_op_name(GET_ALL_NEIGHBORS); + for (const auto &node_id : node_list) { + request.add_id(static_cast(node_id)); + } + request.add_type(static_cast(neighbor_type)); + RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out)); +#endif + return Status::OK(); +} + +Status GraphDataClient::GetSampledNeighbors(const std::vector &node_list, + const std::vector &neighbor_nums, + const std::vector &neighbor_types, std::shared_ptr *out) { +#if !defined(_WIN32) && !defined(_WIN64) + GnnGraphDataRequestPb request; + GnnGraphDataResponsePb response; + request.set_op_name(GET_SAMPLED_NEIGHBORS); + for (const auto &node_id : node_list) { + request.add_id(static_cast(node_id)); + } + for (const auto &num : neighbor_nums) { + request.add_number(static_cast(num)); + } + for (const auto &type : neighbor_types) { + request.add_type(static_cast(type)); + } + RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out)); +#endif + return Status::OK(); +} + +Status GraphDataClient::GetNegSampledNeighbors(const std::vector &node_list, NodeIdType samples_num, + NodeType neg_neighbor_type, std::shared_ptr *out) { +#if !defined(_WIN32) && !defined(_WIN64) + GnnGraphDataRequestPb request; + GnnGraphDataResponsePb response; + request.set_op_name(GET_NEG_SAMPLED_NEIGHBORS); + for (const auto &node_id : node_list) { + request.add_id(static_cast(node_id)); + } + request.add_number(static_cast(samples_num)); + request.add_type(static_cast(neg_neighbor_type)); + RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out)); +#endif + return Status::OK(); +} + +Status GraphDataClient::GraphDataClient::RandomWalk(const std::vector &node_list, + const 
std::vector &meta_path, float step_home_param, + float step_away_param, NodeIdType default_node, + std::shared_ptr *out) { +#if !defined(_WIN32) && !defined(_WIN64) + GnnGraphDataRequestPb request; + GnnGraphDataResponsePb response; + request.set_op_name(RANDOM_WALK); + for (const auto &node_id : node_list) { + request.add_id(static_cast(node_id)); + } + for (const auto &type : meta_path) { + request.add_type(static_cast(type)); + } + auto walk_param = request.mutable_random_walk(); + walk_param->set_p(step_home_param); + walk_param->set_q(step_away_param); + walk_param->set_default_id(static_cast(default_node)); + RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out)); +#endif + return Status::OK(); +} + +Status GraphDataClient::GetNodeFeature(const std::shared_ptr &nodes, + const std::vector &feature_types, TensorRow *out) { +#if !defined(_WIN32) && !defined(_WIN64) + if (!nodes || nodes->Size() == 0) { + RETURN_STATUS_UNEXPECTED("Input nodes is empty"); + } + CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Input feature_types is empty"); + + GnnGraphDataRequestPb request; + GnnGraphDataResponsePb response; + request.set_op_name(GET_NODE_FEATURE); + for (const auto &type : feature_types) { + request.add_type(static_cast(type)); + } + RETURN_IF_NOT_OK(TensorToPb(nodes, request.mutable_id_tensor())); + RETURN_IF_NOT_OK(GetGraphData(request, &response)); + CHECK_FAIL_RETURN_UNEXPECTED(feature_types.size() == response.result_data().size(), + "The number of feature types returned by the server is wrong"); + if (response.result_data().size() > 0) { + size_t i = 0; + for (const auto &result : response.result_data()) { + std::shared_ptr tensor; + RETURN_IF_NOT_OK(PbToTensor(&result, &tensor)); + std::shared_ptr fea_tensor; + RETURN_IF_NOT_OK(ParseNodeFeatureFromMemory(nodes, feature_types[i], tensor, &fea_tensor)); + out->emplace_back(std::move(fea_tensor)); + ++i; + } + } else { + RETURN_STATUS_UNEXPECTED("RPC failed: The number of returned tensor is 
abnormal"); + } +#endif + return Status::OK(); +} + +Status GraphDataClient::GetEdgeFeature(const std::shared_ptr &edges, + const std::vector &feature_types, TensorRow *out) { +#if !defined(_WIN32) && !defined(_WIN64) + if (!edges || edges->Size() == 0) { + RETURN_STATUS_UNEXPECTED("Input edges is empty"); + } + CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Input feature_types is empty"); + + GnnGraphDataRequestPb request; + GnnGraphDataResponsePb response; + request.set_op_name(GET_EDGE_FEATURE); + for (const auto &type : feature_types) { + request.add_type(static_cast(type)); + } + RETURN_IF_NOT_OK(TensorToPb(edges, request.mutable_id_tensor())); + RETURN_IF_NOT_OK(GetGraphData(request, &response)); + CHECK_FAIL_RETURN_UNEXPECTED(feature_types.size() == response.result_data().size(), + "The number of feature types returned by the server is wrong"); + if (response.result_data().size() > 0) { + size_t i = 0; + for (const auto &result : response.result_data()) { + std::shared_ptr tensor; + RETURN_IF_NOT_OK(PbToTensor(&result, &tensor)); + std::shared_ptr fea_tensor; + RETURN_IF_NOT_OK(ParseEdgeFeatureFromMemory(edges, feature_types[i], tensor, &fea_tensor)); + out->emplace_back(std::move(fea_tensor)); + ++i; + } + } else { + RETURN_STATUS_UNEXPECTED("RPC failed: The number of returned tensor is abnormal"); + } +#endif + return Status::OK(); +} + +Status GraphDataClient::GraphInfo(py::dict *out) { +#if !defined(_WIN32) && !defined(_WIN64) + RETURN_IF_NOT_OK(CheckPid()); + void *tag; + bool ok; + grpc::Status status; + grpc::ClientContext ctx; + grpc::CompletionQueue cq; + GnnMetaInfoRequestPb request; + GnnMetaInfoResponsePb response; + // One minute timeout + auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60); + ctx.set_deadline(deadline); + std::unique_ptr> rpc( + stub_->PrepareAsyncGetMetaInfo(&ctx, request, &cq)); + rpc->StartCall(); + rpc->Finish(&response, &status, &response); + + { + py::gil_scoped_release gil_release; + auto 
success = cq.Next(&tag, &ok); + CHECK_FAIL_RETURN_UNEXPECTED(success, "Expect successful"); + CHECK_FAIL_RETURN_UNEXPECTED(tag == &response, "Expect the same tag"); + CHECK_FAIL_RETURN_UNEXPECTED(ok, "Expect successful"); + } + + if (status.ok()) { + if (response.error_msg() != "Success") { + RETURN_STATUS_UNEXPECTED(response.error_msg()); + } else { + MetaInfo meta_info; + for (const auto &node : response.node_info()) { + meta_info.node_type.emplace_back(static_cast(node.type())); + meta_info.node_num[static_cast(node.type())] = static_cast(node.num()); + } + for (const auto &edge : response.edge_info()) { + meta_info.edge_type.emplace_back(static_cast(edge.type())); + meta_info.edge_num[static_cast(edge.type())] = static_cast(edge.num()); + } + for (const auto &feature_type : response.node_feature_type()) { + meta_info.node_feature_type.emplace_back(static_cast(feature_type)); + } + for (const auto &feature_type : response.edge_feature_type()) { + meta_info.edge_feature_type.emplace_back(static_cast(feature_type)); + } + (*out)["node_type"] = py::cast(meta_info.node_type); + (*out)["edge_type"] = py::cast(meta_info.edge_type); + (*out)["node_num"] = py::cast(meta_info.node_num); + (*out)["edge_num"] = py::cast(meta_info.edge_num); + (*out)["node_feature_type"] = py::cast(meta_info.node_feature_type); + (*out)["edge_feature_type"] = py::cast(meta_info.edge_feature_type); + } + } else { + auto error_code = status.error_code(); + RETURN_STATUS_UNEXPECTED(status.error_message() + ". 
GRPC Code " + std::to_string(error_code)); + } +#endif + return Status::OK(); +} + +#if !defined(_WIN32) && !defined(_WIN64) +Status GraphDataClient::GetGraphData(const GnnGraphDataRequestPb &request, GnnGraphDataResponsePb *response) { + RETURN_IF_NOT_OK(CheckPid()); + void *tag; + bool ok; + grpc::Status status; + grpc::ClientContext ctx; + grpc::CompletionQueue cq; + // One minute timeout + auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60); + ctx.set_deadline(deadline); + std::unique_ptr> rpc( + stub_->PrepareAsyncGetGraphData(&ctx, request, &cq)); + rpc->StartCall(); + rpc->Finish(response, &status, response); + + { + py::gil_scoped_release gil_release; + auto success = cq.Next(&tag, &ok); + CHECK_FAIL_RETURN_UNEXPECTED(success, "Expect successful"); + CHECK_FAIL_RETURN_UNEXPECTED(tag == response, "Expect the same tag"); + CHECK_FAIL_RETURN_UNEXPECTED(ok, "Expect successful"); + } + + if (status.ok()) { + if (response->error_msg() != "Success") { + RETURN_STATUS_UNEXPECTED(response->error_msg()); + } + } else { + auto error_code = status.error_code(); + RETURN_STATUS_UNEXPECTED(status.error_message() + ". 
GRPC Code " + std::to_string(error_code)); + } + + return Status::OK(); +} + +Status GraphDataClient::GetGraphDataTensor(const GnnGraphDataRequestPb &request, GnnGraphDataResponsePb *response, + std::shared_ptr *out) { + RETURN_IF_NOT_OK(GetGraphData(request, response)); + if (1 == response->result_data().size()) { + const TensorPb &result = response->result_data()[0]; + std::shared_ptr tensor; + RETURN_IF_NOT_OK(PbToTensor(&result, &tensor)); + *out = std::move(tensor); + } else { + RETURN_STATUS_UNEXPECTED("RPC failed: The number of returned tensor is abnormal"); + } + return Status::OK(); +} + +Status GraphDataClient::ParseNodeFeatureFromMemory(const std::shared_ptr &nodes, FeatureType feature_type, + const std::shared_ptr &memory_tensor, + std::shared_ptr *out) { + std::shared_ptr default_feature; + // If no feature can be obtained, fill in the default value + RETURN_IF_NOT_OK(GetNodeDefaultFeature(feature_type, &default_feature)); + TensorShape shape(default_feature->shape()); + auto shape_vec = nodes->shape().AsVector(); + dsize_t size = std::accumulate(shape_vec.begin(), shape_vec.end(), 1, std::multiplies()); + shape = shape.PrependDim(size); + std::shared_ptr fea_tensor; + RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, default_feature->type(), &fea_tensor)); + + dsize_t index = 0; + auto fea_addr_itr = memory_tensor->begin(); + for (auto node_itr = nodes->begin(); node_itr != nodes->end(); ++node_itr) { + int64_t offset = *fea_addr_itr; + fea_addr_itr++; + int64_t len = *fea_addr_itr; + fea_addr_itr++; + if (*node_itr == kDefaultNodeId || offset < 0 || len <= 0) { + RETURN_IF_NOT_OK(fea_tensor->InsertTensor({index}, default_feature)); + } else { + uchar *start_addr_of_index = nullptr; + TensorShape remaining({-1}); + RETURN_IF_NOT_OK(fea_tensor->StartAddrOfIndex({index}, &start_addr_of_index, &remaining)); + RETURN_IF_NOT_OK(graph_shared_memory_->GetData(start_addr_of_index, len, offset, len)); + } + index++; + } + + TensorShape reshape(nodes->shape()); + for 
(auto s : default_feature->shape().AsVector()) { + reshape = reshape.AppendDim(s); + } + RETURN_IF_NOT_OK(fea_tensor->Reshape(reshape)); + fea_tensor->Squeeze(); + + *out = std::move(fea_tensor); + return Status::OK(); +} + +Status GraphDataClient::ParseEdgeFeatureFromMemory(const std::shared_ptr &edges, FeatureType feature_type, + const std::shared_ptr &memory_tensor, + std::shared_ptr *out) { + std::shared_ptr default_feature; + // If no feature can be obtained, fill in the default value + RETURN_IF_NOT_OK(GetEdgeDefaultFeature(feature_type, &default_feature)); + TensorShape shape(default_feature->shape()); + auto shape_vec = edges->shape().AsVector(); + dsize_t size = std::accumulate(shape_vec.begin(), shape_vec.end(), 1, std::multiplies()); + shape = shape.PrependDim(size); + std::shared_ptr fea_tensor; + RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, default_feature->type(), &fea_tensor)); + + dsize_t index = 0; + auto fea_addr_itr = memory_tensor->begin(); + for (auto edge_itr = edges->begin(); edge_itr != edges->end(); ++edge_itr) { + int64_t offset = *fea_addr_itr; + fea_addr_itr++; + int64_t len = *fea_addr_itr; + fea_addr_itr++; + if (offset < 0 || len <= 0) { + RETURN_IF_NOT_OK(fea_tensor->InsertTensor({index}, default_feature)); + } else { + uchar *start_addr_of_index = nullptr; + TensorShape remaining({-1}); + RETURN_IF_NOT_OK(fea_tensor->StartAddrOfIndex({index}, &start_addr_of_index, &remaining)); + RETURN_IF_NOT_OK(graph_shared_memory_->GetData(start_addr_of_index, len, offset, len)); + } + index++; + } + + TensorShape reshape(edges->shape()); + for (auto s : default_feature->shape().AsVector()) { + reshape = reshape.AppendDim(s); + } + RETURN_IF_NOT_OK(fea_tensor->Reshape(reshape)); + fea_tensor->Squeeze(); + + *out = std::move(fea_tensor); + return Status::OK(); +} + +Status GraphDataClient::GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr *out_feature) { + auto itr = default_node_feature_map_.find(feature_type); + if (itr == 
default_node_feature_map_.end()) { + std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); + RETURN_STATUS_UNEXPECTED(err_msg); + } else { + *out_feature = itr->second; + } + return Status::OK(); +} + +Status GraphDataClient::GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr *out_feature) { + auto itr = default_edge_feature_map_.find(feature_type); + if (itr == default_edge_feature_map_.end()) { + std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); + RETURN_STATUS_UNEXPECTED(err_msg); + } else { + *out_feature = itr->second; + } + return Status::OK(); +} + +Status GraphDataClient::RegisterToServer() { + RETURN_IF_NOT_OK(CheckPid()); + void *tag; + bool ok; + grpc::Status status; + grpc::ClientContext ctx; + grpc::CompletionQueue cq; + GnnClientRegisterRequestPb request; + GnnClientRegisterResponsePb response; + request.set_pid(static_cast(pid_)); + // One minute timeout + auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60); + ctx.set_deadline(deadline); + std::unique_ptr> rpc( + stub_->PrepareAsyncClientRegister(&ctx, request, &cq)); + rpc->StartCall(); + rpc->Finish(&response, &status, &response); + + { + py::gil_scoped_release gil_release; + auto success = cq.Next(&tag, &ok); + CHECK_FAIL_RETURN_UNEXPECTED(success, "Expect successful"); + CHECK_FAIL_RETURN_UNEXPECTED(tag == &response, "Expect the same tag"); + CHECK_FAIL_RETURN_UNEXPECTED(ok, "Expect successful"); + } + + if (status.ok()) { + if (response.error_msg() == "Success") { + registered_ = true; + data_schema_ = mindrecord::json::parse(response.data_schema()); + shared_memory_key_ = static_cast(response.shared_memory_key()); + shared_memory_size_ = response.shared_memory_size(); + MS_LOG(INFO) << "Register success, recv data_schema:" << response.data_schema(); + for (auto feature_info : response.default_node_feature()) { + std::shared_ptr tensor; + RETURN_IF_NOT_OK(PbToTensor(&feature_info.feature(), &tensor)); + 
default_node_feature_map_[feature_info.type()] = tensor; + } + for (auto feature_info : response.default_edge_feature()) { + std::shared_ptr tensor; + RETURN_IF_NOT_OK(PbToTensor(&feature_info.feature(), &tensor)); + default_edge_feature_map_[feature_info.type()] = tensor; + } + } else { + RETURN_STATUS_UNEXPECTED(response.error_msg()); + } + } else { + auto error_code = status.error_code(); + RETURN_STATUS_UNEXPECTED(status.error_message() + ". GRPC Code " + std::to_string(error_code)); + } + return Status::OK(); +} + +Status GraphDataClient::UnRegisterToServer() { + RETURN_IF_NOT_OK(CheckPid()); + MS_LOG(INFO) << "Graph data client send unregistered to server "; + void *tag; + bool ok; + grpc::Status status; + grpc::ClientContext ctx; + grpc::CompletionQueue cq; + GnnClientUnRegisterRequestPb request; + GnnClientUnRegisterResponsePb response; + request.set_pid(static_cast(pid_)); + // One minute timeout + auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60); + ctx.set_deadline(deadline); + std::unique_ptr> rpc( + stub_->PrepareAsyncClientUnRegister(&ctx, request, &cq)); + rpc->StartCall(); + rpc->Finish(&response, &status, &response); + { + py::gil_scoped_release gil_release; + auto success = cq.Next(&tag, &ok); + CHECK_FAIL_RETURN_UNEXPECTED(success, "Expect successful"); + CHECK_FAIL_RETURN_UNEXPECTED(tag == &response, "Expect the same tag"); + CHECK_FAIL_RETURN_UNEXPECTED(ok, "Expect successful"); + } + if (status.ok()) { + if (response.error_msg() == "Success") { + MS_LOG(INFO) << "Unregister success."; + registered_ = false; + } else { + RETURN_STATUS_UNEXPECTED(response.error_msg()); + } + } else { + auto error_code = status.error_code(); + RETURN_STATUS_UNEXPECTED(status.error_message() + ". 
GRPC Code " + std::to_string(error_code)); + } + return Status::OK(); +} + +Status GraphDataClient::InitFeatureParser() { + // get shared memory + graph_shared_memory_ = std::make_unique(shared_memory_size_, shared_memory_key_); + RETURN_IF_NOT_OK(graph_shared_memory_->GetSharedMemory()); + // build feature parser + graph_feature_parser_ = std::make_unique(ShardColumn(data_schema_)); + + return Status::OK(); +} +#endif + +} // namespace gnn +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.h new file mode 100644 index 00000000000..fc0cd58f574 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.h @@ -0,0 +1,185 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_CLIENT_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_CLIENT_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#if !defined(_WIN32) && !defined(_WIN64) +#include "proto/gnn_graph_data.grpc.pb.h" +#include "proto/gnn_graph_data.pb.h" +#endif +#include "minddata/dataset/engine/gnn/graph_data.h" +#include "minddata/dataset/engine/gnn/graph_feature_parser.h" +#if !defined(_WIN32) && !defined(_WIN64) +#include "minddata/dataset/engine/gnn/graph_shared_memory.h" +#endif +#include "minddata/mindrecord/include/common/shard_utils.h" +#include "minddata/mindrecord/include/shard_column.h" + +namespace mindspore { +namespace dataset { +namespace gnn { + +class GraphDataClient : public GraphData { + public: + // Constructor + // @param std::string dataset_file - + // @param int32_t num_workers - number of parallel threads + GraphDataClient(const std::string &dataset_file, const std::string &hostname, int32_t port); + + ~GraphDataClient(); + + Status Init() override; + + Status Stop() override; + + // Get all nodes from the graph. + // @param NodeType node_type - type of node + // @param std::shared_ptr *out - Returned nodes id + // @return Status - The error code return + Status GetAllNodes(NodeType node_type, std::shared_ptr *out) override; + + // Get all edges from the graph. + // @param NodeType edge_type - type of edge + // @param std::shared_ptr *out - Returned edge ids + // @return Status - The error code return + Status GetAllEdges(EdgeType edge_type, std::shared_ptr *out) override; + + // Get the node id from the edge. + // @param std::vector edge_list - List of edges + // @param std::shared_ptr *out - Returned node ids + // @return Status - The error code return + Status GetNodesFromEdges(const std::vector &edge_list, std::shared_ptr *out) override; + + // All neighbors of the acquisition node. 
+ // @param std::vector node_list - List of nodes + // @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported + // @param std::shared_ptr *out - Returned neighbor's id. Because the number of neighbors at different nodes is + // different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors + // is not enough, fill in tensor as -1. + // @return Status - The error code return + Status GetAllNeighbors(const std::vector &node_list, NodeType neighbor_type, + std::shared_ptr *out) override; + + // Get sampled neighbors. + // @param std::vector node_list - List of nodes + // @param std::vector neighbor_nums - Number of neighbors sampled per hop + // @param std::vector neighbor_types - Neighbor type sampled per hop + // @param std::shared_ptr *out - Returned neighbor's id. + // @return Status - The error code return + Status GetSampledNeighbors(const std::vector &node_list, const std::vector &neighbor_nums, + const std::vector &neighbor_types, std::shared_ptr *out) override; + + // Get negative sampled neighbors. + // @param std::vector node_list - List of nodes + // @param NodeIdType samples_num - Number of neighbors sampled + // @param NodeType neg_neighbor_type - The type of negative neighbor. + // @param std::shared_ptr *out - Returned negative neighbor's id. + // @return Status - The error code return + Status GetNegSampledNeighbors(const std::vector &node_list, NodeIdType samples_num, + NodeType neg_neighbor_type, std::shared_ptr *out) override; + + // Node2vec random walk. 
+ // @param std::vector node_list - List of nodes + // @param std::vector meta_path - node type of each step + // @param float step_home_param - return hyper parameter in node2vec algorithm + // @param float step_away_param - inout hyper parameter in node2vec algorithm + // @param NodeIdType default_node - default node id + // @param std::shared_ptr *out - Returned nodes id in walk path + // @return Status - The error code return + Status RandomWalk(const std::vector &node_list, const std::vector &meta_path, + float step_home_param, float step_away_param, NodeIdType default_node, + std::shared_ptr *out) override; + + // Get the feature of a node + // @param std::shared_ptr nodes - List of nodes + // @param std::vector feature_types - Types of features, An error will be reported if the feature type + // does not exist. + // @param TensorRow *out - Returned features + // @return Status - The error code return + Status GetNodeFeature(const std::shared_ptr &nodes, const std::vector &feature_types, + TensorRow *out) override; + + // Get the feature of a edge + // @param std::shared_ptr edges - List of edges + // @param std::vector feature_types - Types of features, An error will be reported if the feature type + // does not exist. 
+ // @param Tensor *out - Returned features + // @return Status - The error code return + Status GetEdgeFeature(const std::shared_ptr &edges, const std::vector &feature_types, + TensorRow *out) override; + + // Return meta information to python layer + Status GraphInfo(py::dict *out) override; + + private: +#if !defined(_WIN32) && !defined(_WIN64) + Status ParseNodeFeatureFromMemory(const std::shared_ptr &nodes, FeatureType feature_type, + const std::shared_ptr &memory_tensor, std::shared_ptr *out); + + Status ParseEdgeFeatureFromMemory(const std::shared_ptr &edges, FeatureType feature_type, + const std::shared_ptr &memory_tensor, std::shared_ptr *out); + + Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr *out_feature); + + Status GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr *out_feature); + + Status GetGraphData(const GnnGraphDataRequestPb &request, GnnGraphDataResponsePb *response); + + Status GetGraphDataTensor(const GnnGraphDataRequestPb &request, GnnGraphDataResponsePb *response, + std::shared_ptr *out); + + Status RegisterToServer(); + + Status UnRegisterToServer(); + + Status InitFeatureParser(); + + Status CheckPid() { + CHECK_FAIL_RETURN_UNEXPECTED(pid_ == getpid(), + "Multi-process mode is not supported, please change to use multi-thread"); + return Status::OK(); + } +#endif + + std::string dataset_file_; + std::string host_; + int32_t port_; + int32_t pid_; + mindrecord::json data_schema_; +#if !defined(_WIN32) && !defined(_WIN64) + std::unique_ptr stub_; + key_t shared_memory_key_; + int64_t shared_memory_size_; + std::unique_ptr graph_feature_parser_; + std::unique_ptr graph_shared_memory_; + std::unordered_map> default_node_feature_map_; + std::unordered_map> default_edge_feature_map_; +#endif + bool registered_; +}; +} // namespace gnn +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_CLIENT_H_ diff --git 
a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.cc similarity index 76% rename from mindspore/ccsrc/minddata/dataset/engine/gnn/graph.cc rename to mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.cc index 7cbfedcf465..a37e92ed4ea 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.cc @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "minddata/dataset/engine/gnn/graph.h" +#include "minddata/dataset/engine/gnn/graph_data_impl.h" #include #include @@ -22,19 +22,25 @@ #include #include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/engine/gnn/graph_loader.h" #include "minddata/dataset/util/random.h" - namespace mindspore { namespace dataset { namespace gnn { -Graph::Graph(std::string dataset_file, int32_t num_workers) - : dataset_file_(dataset_file), num_workers_(num_workers), rnd_(GetRandomDevice()), random_walk_(this) { +GraphDataImpl::GraphDataImpl(std::string dataset_file, int32_t num_workers, bool server_mode) + : dataset_file_(dataset_file), + num_workers_(num_workers), + rnd_(GetRandomDevice()), + random_walk_(this), + server_mode_(server_mode) { rnd_.seed(GetSeed()); MS_LOG(INFO) << "num_workers:" << num_workers; } -Status Graph::GetAllNodes(NodeType node_type, std::shared_ptr *out) { +GraphDataImpl::~GraphDataImpl() {} + +Status GraphDataImpl::GetAllNodes(NodeType node_type, std::shared_ptr *out) { auto itr = node_type_map_.find(node_type); if (itr == node_type_map_.end()) { std::string err_msg = "Invalid node type:" + std::to_string(node_type); @@ -46,8 +52,8 @@ Status Graph::GetAllNodes(NodeType node_type, std::shared_ptr *out) { } template -Status Graph::CreateTensorByVector(const std::vector> &data, DataType type, - std::shared_ptr *out) { +Status GraphDataImpl::CreateTensorByVector(const 
std::vector> &data, DataType type, + std::shared_ptr *out) { if (!type.IsCompatible()) { RETURN_STATUS_UNEXPECTED("Data type not compatible"); } @@ -72,7 +78,7 @@ Status Graph::CreateTensorByVector(const std::vector> &data, Data } template -Status Graph::ComplementVector(std::vector> *data, size_t max_size, T default_value) { +Status GraphDataImpl::ComplementVector(std::vector> *data, size_t max_size, T default_value) { if (!data || data->empty()) { RETURN_STATUS_UNEXPECTED("Input data is empty"); } @@ -89,7 +95,7 @@ Status Graph::ComplementVector(std::vector> *data, size_t max_siz return Status::OK(); } -Status Graph::GetAllEdges(EdgeType edge_type, std::shared_ptr *out) { +Status GraphDataImpl::GetAllEdges(EdgeType edge_type, std::shared_ptr *out) { auto itr = edge_type_map_.find(edge_type); if (itr == edge_type_map_.end()) { std::string err_msg = "Invalid edge type:" + std::to_string(edge_type); @@ -100,7 +106,7 @@ Status Graph::GetAllEdges(EdgeType edge_type, std::shared_ptr *out) { return Status::OK(); } -Status Graph::GetNodesFromEdges(const std::vector &edge_list, std::shared_ptr *out) { +Status GraphDataImpl::GetNodesFromEdges(const std::vector &edge_list, std::shared_ptr *out) { if (edge_list.empty()) { RETURN_STATUS_UNEXPECTED("Input edge_list is empty"); } @@ -122,8 +128,8 @@ Status Graph::GetNodesFromEdges(const std::vector &edge_list, std::s return Status::OK(); } -Status Graph::GetAllNeighbors(const std::vector &node_list, NodeType neighbor_type, - std::shared_ptr *out) { +Status GraphDataImpl::GetAllNeighbors(const std::vector &node_list, NodeType neighbor_type, + std::shared_ptr *out) { CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); RETURN_IF_NOT_OK(CheckNeighborType(neighbor_type)); @@ -143,7 +149,7 @@ Status Graph::GetAllNeighbors(const std::vector &node_list, NodeType return Status::OK(); } -Status Graph::CheckSamplesNum(NodeIdType samples_num) { +Status GraphDataImpl::CheckSamplesNum(NodeIdType samples_num) { 
NodeIdType all_nodes_number = std::accumulate(node_type_map_.begin(), node_type_map_.end(), 0, [](NodeIdType t1, const auto &t2) -> NodeIdType { return t1 + t2.second.size(); }); @@ -155,7 +161,7 @@ Status Graph::CheckSamplesNum(NodeIdType samples_num) { return Status::OK(); } -Status Graph::CheckNeighborType(NodeType neighbor_type) { +Status GraphDataImpl::CheckNeighborType(NodeType neighbor_type) { if (node_type_map_.find(neighbor_type) == node_type_map_.end()) { std::string err_msg = "Invalid neighbor type:" + std::to_string(neighbor_type); RETURN_STATUS_UNEXPECTED(err_msg); @@ -163,9 +169,9 @@ Status Graph::CheckNeighborType(NodeType neighbor_type) { return Status::OK(); } -Status Graph::GetSampledNeighbors(const std::vector &node_list, - const std::vector &neighbor_nums, - const std::vector &neighbor_types, std::shared_ptr *out) { +Status GraphDataImpl::GetSampledNeighbors(const std::vector &node_list, + const std::vector &neighbor_nums, + const std::vector &neighbor_types, std::shared_ptr *out) { CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); CHECK_FAIL_RETURN_UNEXPECTED(neighbor_nums.size() == neighbor_types.size(), "The sizes of neighbor_nums and neighbor_types are inconsistent."); @@ -205,8 +211,9 @@ Status Graph::GetSampledNeighbors(const std::vector &node_list, return Status::OK(); } -Status Graph::NegativeSample(const std::vector &data, const std::unordered_set &exclude_data, - int32_t samples_num, std::vector *out_samples) { +Status GraphDataImpl::NegativeSample(const std::vector &data, + const std::unordered_set &exclude_data, int32_t samples_num, + std::vector *out_samples) { CHECK_FAIL_RETURN_UNEXPECTED(!data.empty(), "Input data is empty."); std::vector shuffled_id(data.size()); std::iota(shuffled_id.begin(), shuffled_id.end(), 0); @@ -223,8 +230,8 @@ Status Graph::NegativeSample(const std::vector &data, const std::uno return Status::OK(); } -Status Graph::GetNegSampledNeighbors(const std::vector &node_list, 
NodeIdType samples_num, - NodeType neg_neighbor_type, std::shared_ptr *out) { +Status GraphDataImpl::GetNegSampledNeighbors(const std::vector &node_list, NodeIdType samples_num, + NodeType neg_neighbor_type, std::shared_ptr *out) { CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); RETURN_IF_NOT_OK(CheckSamplesNum(samples_num)); RETURN_IF_NOT_OK(CheckNeighborType(neg_neighbor_type)); @@ -260,9 +267,9 @@ Status Graph::GetNegSampledNeighbors(const std::vector &node_list, N return Status::OK(); } -Status Graph::RandomWalk(const std::vector &node_list, const std::vector &meta_path, - float step_home_param, float step_away_param, NodeIdType default_node, - std::shared_ptr *out) { +Status GraphDataImpl::RandomWalk(const std::vector &node_list, const std::vector &meta_path, + float step_home_param, float step_away_param, NodeIdType default_node, + std::shared_ptr *out) { RETURN_IF_NOT_OK(random_walk_.Build(node_list, meta_path, step_home_param, step_away_param, default_node)); std::vector> walks; RETURN_IF_NOT_OK(random_walk_.SimulateWalk(&walks)); @@ -270,7 +277,7 @@ Status Graph::RandomWalk(const std::vector &node_list, const std::ve return Status::OK(); } -Status Graph::GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr *out_feature) { +Status GraphDataImpl::GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr *out_feature) { auto itr = default_node_feature_map_.find(feature_type); if (itr == default_node_feature_map_.end()) { std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); @@ -281,7 +288,7 @@ Status Graph::GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr *out_feature) { +Status GraphDataImpl::GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr *out_feature) { auto itr = default_edge_feature_map_.find(feature_type); if (itr == default_edge_feature_map_.end()) { std::string err_msg = "Invalid feature type:" + std::to_string(feature_type); @@ -292,8 +299,8 @@ 
Status Graph::GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr &nodes, const std::vector &feature_types, - TensorRow *out) { +Status GraphDataImpl::GetNodeFeature(const std::shared_ptr &nodes, + const std::vector &feature_types, TensorRow *out) { if (!nodes || nodes->Size() == 0) { RETURN_STATUS_UNEXPECTED("Input nodes is empty"); } @@ -339,8 +346,49 @@ Status Graph::GetNodeFeature(const std::shared_ptr &nodes, const std::ve return Status::OK(); } -Status Graph::GetEdgeFeature(const std::shared_ptr &edges, const std::vector &feature_types, - TensorRow *out) { +Status GraphDataImpl::GetNodeFeatureSharedMemory(const std::shared_ptr &nodes, FeatureType type, + std::shared_ptr *out) { + if (!nodes || nodes->Size() == 0) { + RETURN_STATUS_UNEXPECTED("Input nodes is empty"); + } + TensorShape shape = nodes->shape().AppendDim(2); + std::shared_ptr fea_tensor; + RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, DataType(DataType::DE_INT64), &fea_tensor)); + + auto out_fea_itr = fea_tensor->begin(); + for (auto node_itr = nodes->begin(); node_itr != nodes->end(); ++node_itr) { + if (*node_itr == kDefaultNodeId) { + *out_fea_itr = -1; + ++out_fea_itr; + *out_fea_itr = -1; + ++out_fea_itr; + } else { + std::shared_ptr node; + RETURN_IF_NOT_OK(GetNodeByNodeId(*node_itr, &node)); + std::shared_ptr feature; + if (!node->GetFeatures(type, &feature).IsOk()) { + *out_fea_itr = -1; + ++out_fea_itr; + *out_fea_itr = -1; + ++out_fea_itr; + } else { + for (auto fea_itr = feature->Value()->begin(); fea_itr != feature->Value()->end(); + ++fea_itr) { + *out_fea_itr = *fea_itr; + ++out_fea_itr; + } + } + } + } + + fea_tensor->Squeeze(); + + *out = std::move(fea_tensor); + return Status::OK(); +} + +Status GraphDataImpl::GetEdgeFeature(const std::shared_ptr &edges, + const std::vector &feature_types, TensorRow *out) { if (!edges || edges->Size() == 0) { RETURN_STATUS_UNEXPECTED("Input edges is empty"); } @@ -382,12 +430,45 @@ Status Graph::GetEdgeFeature(const std::shared_ptr 
&edges, const std::ve return Status::OK(); } -Status Graph::Init() { +Status GraphDataImpl::GetEdgeFeatureSharedMemory(const std::shared_ptr &edges, FeatureType type, + std::shared_ptr *out) { + if (!edges || edges->Size() == 0) { + RETURN_STATUS_UNEXPECTED("Input edges is empty"); + } + TensorShape shape = edges->shape().AppendDim(2); + std::shared_ptr fea_tensor; + RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, DataType(DataType::DE_INT64), &fea_tensor)); + + auto out_fea_itr = fea_tensor->begin(); + for (auto edge_itr = edges->begin(); edge_itr != edges->end(); ++edge_itr) { + std::shared_ptr edge; + RETURN_IF_NOT_OK(GetEdgeByEdgeId(*edge_itr, &edge)); + std::shared_ptr feature; + if (!edge->GetFeatures(type, &feature).IsOk()) { + *out_fea_itr = -1; + ++out_fea_itr; + *out_fea_itr = -1; + ++out_fea_itr; + } else { + for (auto fea_itr = feature->Value()->begin(); fea_itr != feature->Value()->end(); ++fea_itr) { + *out_fea_itr = *fea_itr; + ++out_fea_itr; + } + } + } + + fea_tensor->Squeeze(); + + *out = std::move(fea_tensor); + return Status::OK(); +} + +Status GraphDataImpl::Init() { RETURN_IF_NOT_OK(LoadNodeAndEdge()); return Status::OK(); } -Status Graph::GetMetaInfo(MetaInfo *meta_info) { +Status GraphDataImpl::GetMetaInfo(MetaInfo *meta_info) { meta_info->node_type.resize(node_type_map_.size()); std::transform(node_type_map_.begin(), node_type_map_.end(), meta_info->node_type.begin(), [](auto itr) { return itr.first; }); @@ -427,7 +508,7 @@ Status Graph::GetMetaInfo(MetaInfo *meta_info) { } #ifdef ENABLE_PYTHON -Status Graph::GraphInfo(py::dict *out) { +Status GraphDataImpl::GraphInfo(py::dict *out) { MetaInfo meta_info; RETURN_IF_NOT_OK(GetMetaInfo(&meta_info)); (*out)["node_type"] = py::cast(meta_info.node_type); @@ -440,18 +521,16 @@ Status Graph::GraphInfo(py::dict *out) { } #endif -Status Graph::LoadNodeAndEdge() { - GraphLoader gl(dataset_file_, num_workers_); +Status GraphDataImpl::LoadNodeAndEdge() { + GraphLoader gl(this, dataset_file_, num_workers_, 
server_mode_); // ask graph_loader to load everything into memory RETURN_IF_NOT_OK(gl.InitAndLoad()); // get all maps - RETURN_IF_NOT_OK(gl.GetNodesAndEdges(&node_id_map_, &edge_id_map_, &node_type_map_, &edge_type_map_, - &node_feature_map_, &edge_feature_map_, &default_node_feature_map_, - &default_edge_feature_map_)); + RETURN_IF_NOT_OK(gl.GetNodesAndEdges()); return Status::OK(); } -Status Graph::GetNodeByNodeId(NodeIdType id, std::shared_ptr *node) { +Status GraphDataImpl::GetNodeByNodeId(NodeIdType id, std::shared_ptr *node) { auto itr = node_id_map_.find(id); if (itr == node_id_map_.end()) { std::string err_msg = "Invalid node id:" + std::to_string(id); @@ -462,7 +541,7 @@ Status Graph::GetNodeByNodeId(NodeIdType id, std::shared_ptr *node) { return Status::OK(); } -Status Graph::GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr *edge) { +Status GraphDataImpl::GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr *edge) { auto itr = edge_id_map_.find(id); if (itr == edge_id_map_.end()) { std::string err_msg = "Invalid edge id:" + std::to_string(id); @@ -473,12 +552,13 @@ Status Graph::GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr *edge) { return Status::OK(); } -Graph::RandomWalkBase::RandomWalkBase(Graph *graph) +GraphDataImpl::RandomWalkBase::RandomWalkBase(GraphDataImpl *graph) : graph_(graph), step_home_param_(1.0), step_away_param_(1.0), default_node_(-1), num_walks_(1), num_workers_(1) {} -Status Graph::RandomWalkBase::Build(const std::vector &node_list, const std::vector &meta_path, - float step_home_param, float step_away_param, const NodeIdType default_node, - int32_t num_walks, int32_t num_workers) { +Status GraphDataImpl::RandomWalkBase::Build(const std::vector &node_list, + const std::vector &meta_path, float step_home_param, + float step_away_param, const NodeIdType default_node, int32_t num_walks, + int32_t num_workers) { CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); node_list_ = node_list; if (meta_path.empty() || 
meta_path.size() > kMaxNumWalks) { @@ -516,7 +596,7 @@ Status Graph::RandomWalkBase::Build(const std::vector &node_list, co return Status::OK(); } -Status Graph::RandomWalkBase::Node2vecWalk(const NodeIdType &start_node, std::vector *walk_path) { +Status GraphDataImpl::RandomWalkBase::Node2vecWalk(const NodeIdType &start_node, std::vector *walk_path) { // Simulate a random walk starting from start node. auto walk = std::vector(1, start_node); // walk is an vector // walk simulate @@ -556,8 +636,8 @@ Status Graph::RandomWalkBase::Node2vecWalk(const NodeIdType &start_node, std::ve return Status::OK(); } -Status Graph::RandomWalkBase::SimulateWalk(std::vector> *walks) { - for (int32_t i = 0; i < num_walks_; i++) { +Status GraphDataImpl::RandomWalkBase::SimulateWalk(std::vector> *walks) { + for (int32_t i = 0; i < num_walks_; ++i) { for (const auto &node : node_list_) { std::vector walk; RETURN_IF_NOT_OK(Node2vecWalk(node, &walk)); @@ -567,8 +647,8 @@ Status Graph::RandomWalkBase::SimulateWalk(std::vector> return Status::OK(); } -Status Graph::RandomWalkBase::GetNodeProbability(const NodeIdType &node_id, const NodeType &node_type, - std::shared_ptr *node_probability) { +Status GraphDataImpl::RandomWalkBase::GetNodeProbability(const NodeIdType &node_id, const NodeType &node_type, + std::shared_ptr *node_probability) { // Generate alias nodes std::shared_ptr node; graph_->GetNodeByNodeId(node_id, &node); @@ -581,8 +661,9 @@ Status Graph::RandomWalkBase::GetNodeProbability(const NodeIdType &node_id, cons return Status::OK(); } -Status Graph::RandomWalkBase::GetEdgeProbability(const NodeIdType &src, const NodeIdType &dst, uint32_t meta_path_index, - std::shared_ptr *edge_probability) { +Status GraphDataImpl::RandomWalkBase::GetEdgeProbability(const NodeIdType &src, const NodeIdType &dst, + uint32_t meta_path_index, + std::shared_ptr *edge_probability) { // Get the alias edge setup lists for a given edge. 
std::shared_ptr src_node; graph_->GetNodeByNodeId(src, &src_node); @@ -616,7 +697,7 @@ Status Graph::RandomWalkBase::GetEdgeProbability(const NodeIdType &src, const No return Status::OK(); } -StochasticIndex Graph::RandomWalkBase::GenerateProbability(const std::vector &probability) { +StochasticIndex GraphDataImpl::RandomWalkBase::GenerateProbability(const std::vector &probability) { uint32_t K = probability.size(); std::vector switch_to_large_index(K, 0); std::vector weight(K, .0); @@ -644,7 +725,7 @@ StochasticIndex Graph::RandomWalkBase::GenerateProbability(const std::vector -std::vector Graph::RandomWalkBase::Normalize(const std::vector &non_normalized_probability) { +std::vector GraphDataImpl::RandomWalkBase::Normalize(const std::vector &non_normalized_probability) { float sum_probability = 1.0 * std::accumulate(non_normalized_probability.begin(), non_normalized_probability.end(), 0); if (sum_probability < kGnnEpsilon) { diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.h similarity index 81% rename from mindspore/ccsrc/minddata/dataset/engine/gnn/graph.h rename to mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.h index cb755b0bed9..d596e99a2d9 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.h @@ -13,8 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_H_ -#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_H_ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_IMPL_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_IMPL_H_ #include #include @@ -25,13 +25,11 @@ #include #include -#include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/core/tensor_row.h" -#include "minddata/dataset/engine/gnn/graph_loader.h" -#include "minddata/dataset/engine/gnn/feature.h" -#include "minddata/dataset/engine/gnn/node.h" -#include "minddata/dataset/engine/gnn/edge.h" -#include "minddata/dataset/util/status.h" +#include "minddata/dataset/engine/gnn/graph_data.h" +#if !defined(_WIN32) && !defined(_WIN64) +#include "minddata/dataset/engine/gnn/graph_shared_memory.h" +#endif +#include "minddata/mindrecord/include/common/shard_utils.h" namespace mindspore { namespace dataset { @@ -41,41 +39,32 @@ const float kGnnEpsilon = 0.0001; const uint32_t kMaxNumWalks = 80; using StochasticIndex = std::pair, std::vector>; -struct MetaInfo { - std::vector node_type; - std::vector edge_type; - std::map node_num; - std::map edge_num; - std::vector node_feature_type; - std::vector edge_feature_type; -}; - -class Graph { +class GraphDataImpl : public GraphData { public: // Constructor // @param std::string dataset_file - // @param int32_t num_workers - number of parallel threads - Graph(std::string dataset_file, int32_t num_workers); + GraphDataImpl(std::string dataset_file, int32_t num_workers, bool server_mode = false); - ~Graph() = default; + ~GraphDataImpl(); // Get all nodes from the graph. // @param NodeType node_type - type of node // @param std::shared_ptr *out - Returned nodes id // @return Status - The error code return - Status GetAllNodes(NodeType node_type, std::shared_ptr *out); + Status GetAllNodes(NodeType node_type, std::shared_ptr *out) override; // Get all edges from the graph. 
// @param NodeType edge_type - type of edge // @param std::shared_ptr *out - Returned edge ids // @return Status - The error code return - Status GetAllEdges(EdgeType edge_type, std::shared_ptr *out); + Status GetAllEdges(EdgeType edge_type, std::shared_ptr *out) override; // Get the node id from the edge. // @param std::vector edge_list - List of edges // @param std::shared_ptr *out - Returned node ids // @return Status - The error code return - Status GetNodesFromEdges(const std::vector &edge_list, std::shared_ptr *out); + Status GetNodesFromEdges(const std::vector &edge_list, std::shared_ptr *out) override; // All neighbors of the acquisition node. // @param std::vector node_list - List of nodes @@ -85,7 +74,7 @@ class Graph { // is not enough, fill in tensor as -1. // @return Status - The error code return Status GetAllNeighbors(const std::vector &node_list, NodeType neighbor_type, - std::shared_ptr *out); + std::shared_ptr *out) override; // Get sampled neighbors. // @param std::vector node_list - List of nodes @@ -94,7 +83,7 @@ class Graph { // @param std::shared_ptr *out - Returned neighbor's id. // @return Status - The error code return Status GetSampledNeighbors(const std::vector &node_list, const std::vector &neighbor_nums, - const std::vector &neighbor_types, std::shared_ptr *out); + const std::vector &neighbor_types, std::shared_ptr *out) override; // Get negative sampled neighbors. // @param std::vector node_list - List of nodes @@ -103,7 +92,7 @@ class Graph { // @param std::shared_ptr *out - Returned negative neighbor's id. // @return Status - The error code return Status GetNegSampledNeighbors(const std::vector &node_list, NodeIdType samples_num, - NodeType neg_neighbor_type, std::shared_ptr *out); + NodeType neg_neighbor_type, std::shared_ptr *out) override; // Node2vec random walk. 
// @param std::vector node_list - List of nodes @@ -115,7 +104,7 @@ class Graph { // @return Status - The error code return Status RandomWalk(const std::vector &node_list, const std::vector &meta_path, float step_home_param, float step_away_param, NodeIdType default_node, - std::shared_ptr *out); + std::shared_ptr *out) override; // Get the feature of a node // @param std::shared_ptr nodes - List of nodes @@ -124,16 +113,22 @@ class Graph { // @param TensorRow *out - Returned features // @return Status - The error code return Status GetNodeFeature(const std::shared_ptr &nodes, const std::vector &feature_types, - TensorRow *out); + TensorRow *out) override; + + Status GetNodeFeatureSharedMemory(const std::shared_ptr &nodes, FeatureType type, + std::shared_ptr *out); // Get the feature of a edge - // @param std::shared_ptr edget - List of edges + // @param std::shared_ptr edges - List of edges // @param std::vector feature_types - Types of features, An error will be reported if the feature type // does not exist. 
// @param Tensor *out - Returned features // @return Status - The error code return - Status GetEdgeFeature(const std::shared_ptr &edget, const std::vector &feature_types, - TensorRow *out); + Status GetEdgeFeature(const std::shared_ptr &edges, const std::vector &feature_types, + TensorRow *out) override; + + Status GetEdgeFeatureSharedMemory(const std::shared_ptr &edges, FeatureType type, + std::shared_ptr *out); // Get meta information of graph // @param MetaInfo *meta_info - Returned meta information @@ -142,15 +137,34 @@ class Graph { #ifdef ENABLE_PYTHON // Return meta information to python layer - Status GraphInfo(py::dict *out); + Status GraphInfo(py::dict *out) override; #endif - Status Init(); + const std::unordered_map> *GetAllDefaultNodeFeatures() { + return &default_node_feature_map_; + } + + const std::unordered_map> *GetAllDefaultEdgeFeatures() { + return &default_edge_feature_map_; + } + + Status Init() override; + + Status Stop() override { return Status::OK(); } + + std::string GetDataSchema() { return data_schema_.dump(); } + +#if !defined(_WIN32) && !defined(_WIN64) + key_t GetSharedMemoryKey() { return graph_shared_memory_->memory_key(); } + + int64_t GetSharedMemorySize() { return graph_shared_memory_->memory_size(); } +#endif private: + friend class GraphLoader; class RandomWalkBase { public: - explicit RandomWalkBase(Graph *graph); + explicit RandomWalkBase(GraphDataImpl *graph); Status Build(const std::vector &node_list, const std::vector &meta_path, float step_home_param = 1.0, float step_away_param = 1.0, NodeIdType default_node = -1, @@ -176,7 +190,7 @@ class Graph { template std::vector Normalize(const std::vector &non_normalized_probability); - Graph *graph_; + GraphDataImpl *graph_; std::vector node_list_; std::vector meta_path_; float step_home_param_; // Return hyper parameter. 
Default is 1.0 @@ -248,7 +262,11 @@ class Graph { int32_t num_workers_; // The number of worker threads std::mt19937 rnd_; RandomWalkBase random_walk_; - + mindrecord::json data_schema_; + bool server_mode_; +#if !defined(_WIN32) && !defined(_WIN64) + std::unique_ptr graph_shared_memory_; +#endif std::unordered_map> node_type_map_; std::unordered_map> node_id_map_; @@ -264,4 +282,4 @@ class Graph { } // namespace gnn } // namespace dataset } // namespace mindspore -#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_H_ +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_IMPL_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_server.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_server.cc new file mode 100644 index 00000000000..2658ba7256d --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_server.cc @@ -0,0 +1,133 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "minddata/dataset/engine/gnn/graph_data_server.h" + +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor_shape.h" +#include "minddata/dataset/engine/gnn/graph_data_impl.h" +#include "minddata/dataset/util/random.h" + +namespace mindspore { +namespace dataset { +namespace gnn { + +GraphDataServer::GraphDataServer(const std::string &dataset_file, int32_t num_workers, const std::string &hostname, + int32_t port, int32_t client_num, bool auto_shutdown) + : dataset_file_(dataset_file), + num_workers_(num_workers), + client_num_(client_num), + max_connected_client_num_(0), + auto_shutdown_(auto_shutdown), + state_(kGdsUninit) { + tg_ = std::make_unique(); + graph_data_impl_ = std::make_unique(dataset_file, num_workers, true); +#if !defined(_WIN32) && !defined(_WIN64) + service_impl_ = std::make_unique(this, graph_data_impl_.get()); + async_server_ = std::make_unique(hostname, port, service_impl_.get()); +#endif +} + +Status GraphDataServer::Init() { +#if defined(_WIN32) || defined(_WIN64) + RETURN_STATUS_UNEXPECTED("Graph data server is not supported in Windows OS"); +#else + set_state(kGdsInitializing); + RETURN_IF_NOT_OK(async_server_->Run()); + // RETURN_IF_NOT_OK(InitGraphDataImpl()); + RETURN_IF_NOT_OK(tg_->CreateAsyncTask("init graph data impl", std::bind(&GraphDataServer::InitGraphDataImpl, this))); + for (int32_t i = 0; i < num_workers_; ++i) { + RETURN_IF_NOT_OK( + tg_->CreateAsyncTask("start async rpc service", std::bind(&GraphDataServer::StartAsyncRpcService, this))); + } + if (auto_shutdown_) { + RETURN_IF_NOT_OK( + tg_->CreateAsyncTask("judge auto shutdown server", std::bind(&GraphDataServer::JudgeAutoShutdownServer, this))); + } + return Status::OK(); +#endif +} + +Status GraphDataServer::InitGraphDataImpl() { + TaskManager::FindMe()->Post(); + Status s = graph_data_impl_->Init(); + if (s.IsOk()) { + set_state(kGdsRunning); + } else { + (void)Stop(); + } + return s; +} + +#if !defined(_WIN32) && 
!defined(_WIN64) +Status GraphDataServer::StartAsyncRpcService() { + TaskManager::FindMe()->Post(); + RETURN_IF_NOT_OK(async_server_->HandleRequest()); + return Status::OK(); +} +#endif + +Status GraphDataServer::JudgeAutoShutdownServer() { + TaskManager::FindMe()->Post(); + while (true) { + if (auto_shutdown_ && (max_connected_client_num_ >= client_num_) && (client_pid_.size() == 0)) { + MS_LOG(INFO) << "All clients have been unregister, automatically exit the server."; + RETURN_IF_NOT_OK(Stop()); + break; + } + if (state_ == kGdsStopped) { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + } + return Status::OK(); +} + +Status GraphDataServer::Stop() { +#if !defined(_WIN32) && !defined(_WIN64) + async_server_->Stop(); +#endif + set_state(kGdsStopped); + graph_data_impl_.reset(); + return Status::OK(); +} + +Status GraphDataServer::ClientRegister(int32_t pid) { + std::unique_lock lck(mutex_); + MS_LOG(INFO) << "client register pid:" << std::to_string(pid); + client_pid_.emplace(pid); + if (client_pid_.size() > max_connected_client_num_) { + max_connected_client_num_ = client_pid_.size(); + } + return Status::OK(); +} +Status GraphDataServer::ClientUnRegister(int32_t pid) { + std::unique_lock lck(mutex_); + auto itr = client_pid_.find(pid); + if (itr != client_pid_.end()) { + client_pid_.erase(itr); + MS_LOG(INFO) << "client unregister pid:" << std::to_string(pid); + } + return Status::OK(); +} + +} // namespace gnn +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_server.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_server.h new file mode 100644 index 00000000000..ee37661d711 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_server.h @@ -0,0 +1,196 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the 
License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVER_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVER_H_ + +#include +#include +#include +#include + +#if !defined(_WIN32) && !defined(_WIN64) +#include "grpcpp/grpcpp.h" +#include "minddata/dataset/engine/gnn/graph_data_service_impl.h" +#include "minddata/dataset/engine/gnn/grpc_async_server.h" +#endif +#include "minddata/dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { +namespace gnn { + +class GraphDataImpl; + +class GraphDataServer { + public: + enum ServerState { kGdsUninit = 0, kGdsInitializing, kGdsRunning, kGdsStopped }; + GraphDataServer(const std::string &dataset_file, int32_t num_workers, const std::string &hostname, int32_t port, + int32_t client_num, bool auto_shutdown); + ~GraphDataServer() = default; + + Status Init(); + + Status Stop(); + + Status ClientRegister(int32_t pid); + Status ClientUnRegister(int32_t pid); + + enum ServerState state() { return state_; } + + bool IsStoped() { + if (state_ == kGdsStopped) { + return true; + } else { + return false; + } + } + + private: + void set_state(enum ServerState state) { state_ = state; } + + Status InitGraphDataImpl(); +#if !defined(_WIN32) && !defined(_WIN64) + Status StartAsyncRpcService(); +#endif + Status JudgeAutoShutdownServer(); + + std::string dataset_file_; + int32_t num_workers_; // The number of worker threads + int32_t client_num_; + int32_t max_connected_client_num_; + bool auto_shutdown_; + enum ServerState state_; + std::unique_ptr tg_; // 
Class for worker management + std::unique_ptr graph_data_impl_; + std::unordered_set client_pid_; + std::mutex mutex_; +#if !defined(_WIN32) && !defined(_WIN64) + std::unique_ptr service_impl_; + std::unique_ptr async_server_; +#endif +}; + +#if !defined(_WIN32) && !defined(_WIN64) +class UntypedCall { + public: + virtual ~UntypedCall() {} + + virtual Status operator()() = 0; +}; + +template +class CallData : public UntypedCall { + public: + enum class STATE : int8_t { CREATE = 1, PROCESS = 2, FINISH = 3 }; + using EnqueueFunction = void (AsyncService::*)(grpc::ServerContext *, RequestMessage *, + grpc::ServerAsyncResponseWriter *, + grpc::CompletionQueue *, grpc::ServerCompletionQueue *, void *); + using HandleRequestFunction = grpc::Status (ServiceImpl::*)(grpc::ServerContext *, const RequestMessage *, + ResponseMessage *); + CallData(ServiceImpl *service_impl, AsyncService *async_service, grpc::ServerCompletionQueue *cq, + EnqueueFunction enqueue_function, HandleRequestFunction handle_request_function) + : status_(STATE::CREATE), + service_impl_(service_impl), + async_service_(async_service), + cq_(cq), + enqueue_function_(enqueue_function), + handle_request_function_(handle_request_function), + responder_(&ctx_) {} + + ~CallData() = default; + + static Status EnqueueRequest(ServiceImpl *service_impl, AsyncService *async_service, grpc::ServerCompletionQueue *cq, + EnqueueFunction enqueue_function, HandleRequestFunction handle_request_function) { + auto call = new CallData( + service_impl, async_service, cq, enqueue_function, handle_request_function); + RETURN_IF_NOT_OK((*call)()); + return Status::OK(); + } + + Status operator()() { + if (status_ == STATE::CREATE) { + status_ = STATE::PROCESS; + (async_service_->*enqueue_function_)(&ctx_, &request_, &responder_, cq_, cq_, this); + } else if (status_ == STATE::PROCESS) { + EnqueueRequest(service_impl_, async_service_, cq_, enqueue_function_, handle_request_function_); + status_ = STATE::FINISH; + // new 
CallData(service_, cq_, this->s_type_); + grpc::Status s = (service_impl_->*handle_request_function_)(&ctx_, &request_, &response_); + responder_.Finish(response_, s, this); + } else { + GPR_ASSERT(status_ == STATE::FINISH); + delete this; + } + return Status::OK(); + } + + private: + STATE status_; + ServiceImpl *service_impl_; + AsyncService *async_service_; + grpc::ServerCompletionQueue *cq_; + EnqueueFunction enqueue_function_; + HandleRequestFunction handle_request_function_; + grpc::ServerContext ctx_; + grpc::ServerAsyncResponseWriter responder_; + RequestMessage request_; + ResponseMessage response_; +}; + +#define ENQUEUE_REQUEST(service_impl, async_service, cq, method, request_msg, response_msg) \ + do { \ + Status s = \ + CallData::EnqueueRequest( \ + service_impl, async_service, cq, &GnnGraphData::AsyncService::Request##method, \ + &gnn::GraphDataServiceImpl::method); \ + RETURN_IF_NOT_OK(s); \ + } while (0) + +class GraphDataGrpcServer : public GrpcAsyncServer { + public: + GraphDataGrpcServer(const std::string &host, int32_t port, GraphDataServiceImpl *service_impl) + : GrpcAsyncServer(host, port), service_impl_(service_impl) {} + + Status RegisterService(grpc::ServerBuilder *builder) { + builder->RegisterService(&svc_); + return Status::OK(); + } + + Status EnqueueRequest() { + ENQUEUE_REQUEST(service_impl_, &svc_, cq_.get(), ClientRegister, GnnClientRegisterRequestPb, + GnnClientRegisterResponsePb); + ENQUEUE_REQUEST(service_impl_, &svc_, cq_.get(), ClientUnRegister, GnnClientUnRegisterRequestPb, + GnnClientUnRegisterResponsePb); + ENQUEUE_REQUEST(service_impl_, &svc_, cq_.get(), GetGraphData, GnnGraphDataRequestPb, GnnGraphDataResponsePb); + ENQUEUE_REQUEST(service_impl_, &svc_, cq_.get(), GetMetaInfo, GnnMetaInfoRequestPb, GnnMetaInfoResponsePb); + return Status::OK(); + } + + Status ProcessRequest(void *tag) { + auto rq = static_cast(tag); + RETURN_IF_NOT_OK((*rq)()); + return Status::OK(); + } + + private: + GraphDataServiceImpl *service_impl_; 
+ GnnGraphData::AsyncService svc_; +}; +#endif +} // namespace gnn +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_service_impl.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_service_impl.cc new file mode 100644 index 00000000000..b926186870d --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_service_impl.cc @@ -0,0 +1,299 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "minddata/dataset/engine/gnn/graph_data_service_impl.h" + +#include +#include +#include + +#include "minddata/dataset/engine/gnn/tensor_proto.h" +#include "minddata/dataset/engine/gnn/graph_data_server.h" + +namespace mindspore { +namespace dataset { +namespace gnn { + +using pFunction = Status (GraphDataServiceImpl::*)(const GnnGraphDataRequestPb *, GnnGraphDataResponsePb *); +static std::unordered_map g_get_graph_data_func_ = { + {GET_ALL_NODES, &GraphDataServiceImpl::GetAllNodes}, + {GET_ALL_EDGES, &GraphDataServiceImpl::GetAllEdges}, + {GET_NODES_FROM_EDGES, &GraphDataServiceImpl::GetNodesFromEdges}, + {GET_ALL_NEIGHBORS, &GraphDataServiceImpl::GetAllNeighbors}, + {GET_SAMPLED_NEIGHBORS, &GraphDataServiceImpl::GetSampledNeighbors}, + {GET_NEG_SAMPLED_NEIGHBORS, &GraphDataServiceImpl::GetNegSampledNeighbors}, + {RANDOM_WALK, &GraphDataServiceImpl::RandomWalk}, + {GET_NODE_FEATURE, &GraphDataServiceImpl::GetNodeFeature}, + {GET_EDGE_FEATURE, &GraphDataServiceImpl::GetEdgeFeature}}; + +GraphDataServiceImpl::GraphDataServiceImpl(GraphDataServer *server, GraphDataImpl *graph_data_impl) + : server_(server), graph_data_impl_(graph_data_impl) {} + +Status GraphDataServiceImpl::FillDefaultFeature(GnnClientRegisterResponsePb *response) { + const auto default_node_features = graph_data_impl_->GetAllDefaultNodeFeatures(); + for (const auto feature : *default_node_features) { + GnnFeatureInfoPb *feature_info = response->add_default_node_feature(); + feature_info->set_type(feature.first); + RETURN_IF_NOT_OK(TensorToPb(feature.second->Value(), feature_info->mutable_feature())); + } + const auto default_edge_features = graph_data_impl_->GetAllDefaultEdgeFeatures(); + for (const auto feature : *default_edge_features) { + GnnFeatureInfoPb *feature_info = response->add_default_edge_feature(); + feature_info->set_type(feature.first); + RETURN_IF_NOT_OK(TensorToPb(feature.second->Value(), feature_info->mutable_feature())); + } + return Status::OK(); +} + 
+grpc::Status GraphDataServiceImpl::ClientRegister(grpc::ServerContext *context, + const GnnClientRegisterRequestPb *request, + GnnClientRegisterResponsePb *response) { + Status s = server_->ClientRegister(request->pid()); + if (s.IsOk()) { + switch (server_->state()) { + case GraphDataServer::kGdsUninit: + case GraphDataServer::kGdsInitializing: + response->set_error_msg("Initializing"); + break; + case GraphDataServer::kGdsRunning: + response->set_error_msg("Success"); + response->set_data_schema(graph_data_impl_->GetDataSchema()); + response->set_shared_memory_key(graph_data_impl_->GetSharedMemoryKey()); + response->set_shared_memory_size(graph_data_impl_->GetSharedMemorySize()); + s = FillDefaultFeature(response); + if (!s.IsOk()) { + response->set_error_msg(s.ToString()); + } + break; + case GraphDataServer::kGdsStopped: + response->set_error_msg("Stoped"); + break; + } + } else { + response->set_error_msg(s.ToString()); + } + return ::grpc::Status::OK; +} + +grpc::Status GraphDataServiceImpl::ClientUnRegister(grpc::ServerContext *context, + const GnnClientUnRegisterRequestPb *request, + GnnClientUnRegisterResponsePb *response) { + Status s = server_->ClientUnRegister(request->pid()); + if (s.IsOk()) { + response->set_error_msg("Success"); + } else { + response->set_error_msg(s.ToString()); + } + return ::grpc::Status::OK; +} + +grpc::Status GraphDataServiceImpl::GetGraphData(grpc::ServerContext *context, const GnnGraphDataRequestPb *request, + GnnGraphDataResponsePb *response) { + // MS_LOG(INFO) << "#### receive GetGraphData:" << request->op_name(); + Status s; + auto iter = g_get_graph_data_func_.find(request->op_name()); + if (iter != g_get_graph_data_func_.end()) { + pFunction func = iter->second; + s = (this->*func)(request, response); + if (s.IsOk()) { + response->set_error_msg("Success"); + } else { + response->set_error_msg(s.ToString()); + } + } else { + response->set_error_msg("Invalid op name."); + } + // MS_LOG(INFO) << "#### end receive 
GetGraphData:" << request->op_name(); + return ::grpc::Status::OK; +} + +grpc::Status GraphDataServiceImpl::GetMetaInfo(grpc::ServerContext *context, const GnnMetaInfoRequestPb *request, + GnnMetaInfoResponsePb *response) { + MetaInfo meta_info; + Status s = graph_data_impl_->GetMetaInfo(&meta_info); + if (s.IsOk()) { + response->set_error_msg("Success"); + for (const auto &type : meta_info.node_type) { + auto node_info = response->add_node_info(); + node_info->set_type(static_cast(type)); + auto itr = meta_info.node_num.find(type); + if (itr != meta_info.node_num.end()) { + node_info->set_num(static_cast(itr->second)); + } else { + node_info->set_num(0); + } + } + for (const auto &type : meta_info.edge_type) { + auto edge_info = response->add_edge_info(); + edge_info->set_type(static_cast(type)); + auto itr = meta_info.edge_num.find(type); + if (itr != meta_info.edge_num.end()) { + edge_info->set_num(static_cast(itr->second)); + } else { + edge_info->set_num(0); + } + } + for (const auto &type : meta_info.node_feature_type) { + response->add_node_feature_type(static_cast(type)); + } + for (const auto &type : meta_info.edge_feature_type) { + response->add_edge_feature_type(static_cast(type)); + } + } else { + response->set_error_msg(s.ToString()); + } + return ::grpc::Status::OK; +} + +Status GraphDataServiceImpl::GetAllNodes(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) { + CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() == 1, "The number of edge types is not 1"); + + std::shared_ptr tensor; + RETURN_IF_NOT_OK(graph_data_impl_->GetAllNodes(static_cast(request->type()[0]), &tensor)); + TensorPb *result = response->add_result_data(); + RETURN_IF_NOT_OK(TensorToPb(tensor, result)); + return Status::OK(); +} + +Status GraphDataServiceImpl::GetAllEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) { + CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() == 1, "The number of edge types is not 1"); + + std::shared_ptr 
tensor; + RETURN_IF_NOT_OK(graph_data_impl_->GetAllEdges(static_cast(request->type()[0]), &tensor)); + TensorPb *result = response->add_result_data(); + RETURN_IF_NOT_OK(TensorToPb(tensor, result)); + return Status::OK(); +} + +Status GraphDataServiceImpl::GetNodesFromEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) { + CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input edge id is empty"); + + std::vector edge_list; + edge_list.resize(request->id().size()); + std::transform(request->id().begin(), request->id().end(), edge_list.begin(), + [](const google::protobuf::int32 id) { return static_cast(id); }); + std::shared_ptr tensor; + RETURN_IF_NOT_OK(graph_data_impl_->GetNodesFromEdges(edge_list, &tensor)); + TensorPb *result = response->add_result_data(); + RETURN_IF_NOT_OK(TensorToPb(tensor, result)); + return Status::OK(); +} + +Status GraphDataServiceImpl::GetAllNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) { + CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input node id is empty"); + CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() == 1, "The number of edge types is not 1"); + + std::vector node_list; + node_list.resize(request->id().size()); + std::transform(request->id().begin(), request->id().end(), node_list.begin(), + [](const google::protobuf::int32 id) { return static_cast(id); }); + std::shared_ptr tensor; + RETURN_IF_NOT_OK(graph_data_impl_->GetAllNeighbors(node_list, static_cast(request->type()[0]), &tensor)); + TensorPb *result = response->add_result_data(); + RETURN_IF_NOT_OK(TensorToPb(tensor, result)); + return Status::OK(); +} + +Status GraphDataServiceImpl::GetSampledNeighbors(const GnnGraphDataRequestPb *request, + GnnGraphDataResponsePb *response) { + CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input node id is empty"); + CHECK_FAIL_RETURN_UNEXPECTED(request->number_size() > 0, "The input neighbor number is empty"); + 
CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() > 0, "The input neighbor type is empty"); + + std::vector node_list; + node_list.resize(request->id().size()); + std::transform(request->id().begin(), request->id().end(), node_list.begin(), + [](const google::protobuf::int32 id) { return static_cast(id); }); + std::vector neighbor_nums; + neighbor_nums.resize(request->number().size()); + std::transform(request->number().begin(), request->number().end(), neighbor_nums.begin(), + [](const google::protobuf::int32 num) { return static_cast(num); }); + std::vector neighbor_types; + neighbor_types.resize(request->type().size()); + std::transform(request->type().begin(), request->type().end(), neighbor_types.begin(), + [](const google::protobuf::int32 type) { return static_cast(type); }); + std::shared_ptr tensor; + RETURN_IF_NOT_OK(graph_data_impl_->GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &tensor)); + TensorPb *result = response->add_result_data(); + RETURN_IF_NOT_OK(TensorToPb(tensor, result)); + return Status::OK(); +} + +Status GraphDataServiceImpl::GetNegSampledNeighbors(const GnnGraphDataRequestPb *request, + GnnGraphDataResponsePb *response) { + CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input node id is empty"); + CHECK_FAIL_RETURN_UNEXPECTED(request->number_size() == 1, "The number of neighbor number is not 1"); + CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() == 1, "The number of neighbor types is not 1"); + + std::vector node_list; + node_list.resize(request->id().size()); + std::transform(request->id().begin(), request->id().end(), node_list.begin(), + [](const google::protobuf::int32 id) { return static_cast(id); }); + std::shared_ptr tensor; + RETURN_IF_NOT_OK(graph_data_impl_->GetNegSampledNeighbors(node_list, static_cast(request->number()[0]), + static_cast(request->type()[0]), &tensor)); + TensorPb *result = response->add_result_data(); + RETURN_IF_NOT_OK(TensorToPb(tensor, result)); + return Status::OK(); +} + 
+Status GraphDataServiceImpl::RandomWalk(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) { + CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input node id is empty"); + CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() > 0, "The input meta path is empty"); + + std::vector node_list; + node_list.resize(request->id().size()); + std::transform(request->id().begin(), request->id().end(), node_list.begin(), + [](const google::protobuf::int32 id) { return static_cast(id); }); + std::vector meta_path; + meta_path.resize(request->type().size()); + std::transform(request->type().begin(), request->type().end(), meta_path.begin(), + [](const google::protobuf::int32 type) { return static_cast(type); }); + std::shared_ptr tensor; + RETURN_IF_NOT_OK(graph_data_impl_->RandomWalk(node_list, meta_path, request->random_walk().p(), + request->random_walk().q(), request->random_walk().default_id(), + &tensor)); + TensorPb *result = response->add_result_data(); + RETURN_IF_NOT_OK(TensorToPb(tensor, result)); + return Status::OK(); +} + +Status GraphDataServiceImpl::GetNodeFeature(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) { + std::shared_ptr nodes; + RETURN_IF_NOT_OK(PbToTensor(&request->id_tensor(), &nodes)); + for (const auto &type : request->type()) { + std::shared_ptr tensor; + RETURN_IF_NOT_OK(graph_data_impl_->GetNodeFeatureSharedMemory(nodes, type, &tensor)); + TensorPb *result = response->add_result_data(); + RETURN_IF_NOT_OK(TensorToPb(tensor, result)); + } + return Status::OK(); +} + +Status GraphDataServiceImpl::GetEdgeFeature(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) { + std::shared_ptr edges; + RETURN_IF_NOT_OK(PbToTensor(&request->id_tensor(), &edges)); + for (const auto &type : request->type()) { + std::shared_ptr tensor; + RETURN_IF_NOT_OK(graph_data_impl_->GetEdgeFeatureSharedMemory(edges, type, &tensor)); + TensorPb *result = response->add_result_data(); + 
RETURN_IF_NOT_OK(TensorToPb(tensor, result)); + } + return Status::OK(); +} + +} // namespace gnn +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_service_impl.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_service_impl.h new file mode 100644 index 00000000000..74996ccae4a --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_service_impl.h @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVICE_IMPL_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVICE_IMPL_H_ + +#include +#include + +#include "minddata/dataset/engine/gnn/graph_data_impl.h" +#include "proto/gnn_graph_data.grpc.pb.h" +#include "proto/gnn_graph_data.pb.h" + +namespace mindspore { +namespace dataset { +namespace gnn { + +class GraphDataServer; + +// class GraphDataServiceImpl : public GnnGraphData::Service { +class GraphDataServiceImpl { + public: + GraphDataServiceImpl(GraphDataServer *server, GraphDataImpl *graph_data_impl); + ~GraphDataServiceImpl() = default; + + grpc::Status ClientRegister(grpc::ServerContext *context, const GnnClientRegisterRequestPb *request, + GnnClientRegisterResponsePb *response); + + grpc::Status ClientUnRegister(grpc::ServerContext *context, const GnnClientUnRegisterRequestPb *request, + GnnClientUnRegisterResponsePb *response); + + grpc::Status GetGraphData(grpc::ServerContext *context, const GnnGraphDataRequestPb *request, + GnnGraphDataResponsePb *response); + + grpc::Status GetMetaInfo(grpc::ServerContext *context, const GnnMetaInfoRequestPb *request, + GnnMetaInfoResponsePb *response); + + Status GetAllNodes(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); + Status GetAllEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); + Status GetNodesFromEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); + Status GetAllNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); + Status GetSampledNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); + Status GetNegSampledNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); + Status RandomWalk(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); + Status GetNodeFeature(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); + 
Status GetEdgeFeature(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response); + + private: + Status FillDefaultFeature(GnnClientRegisterResponsePb *response); + + GraphDataServer *server_; + GraphDataImpl *graph_data_impl_; +}; + +} // namespace gnn +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVICE_IMPL_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_feature_parser.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_feature_parser.cc new file mode 100644 index 00000000000..f09bf8abe8c --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_feature_parser.cc @@ -0,0 +1,106 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "minddata/dataset/engine/gnn/graph_feature_parser.h" + +#include +#include + +#include "mindspore/ccsrc/minddata/mindrecord/include/shard_error.h" + +namespace mindspore { +namespace dataset { +namespace gnn { + +using mindrecord::MSRStatus; + +GraphFeatureParser::GraphFeatureParser(const ShardColumn &shard_column) { + shard_column_ = std::make_unique(shard_column); +} + +Status GraphFeatureParser::LoadFeatureTensor(const std::string &key, const std::vector &col_blob, + std::shared_ptr *tensor) { + const unsigned char *data = nullptr; + std::unique_ptr data_ptr; + uint64_t n_bytes = 0, col_type_size = 1; + mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; + std::vector column_shape; + MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type, + &col_type_size, &column_shape); + CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key); + if (data == nullptr) data = reinterpret_cast(&data_ptr[0]); + RETURN_IF_NOT_OK(Tensor::CreateFromMemory(std::move(TensorShape({static_cast(n_bytes / col_type_size)})), + std::move(DataType(mindrecord::ColumnDataTypeNameNormalized[col_type])), + data, tensor)); + return Status::OK(); +} + +#if !defined(_WIN32) && !defined(_WIN64) +Status GraphFeatureParser::LoadFeatureToSharedMemory(const std::string &key, const std::vector &col_blob, + GraphSharedMemory *shared_memory, + std::shared_ptr *out_tensor) { + const unsigned char *data = nullptr; + std::unique_ptr data_ptr; + uint64_t n_bytes = 0, col_type_size = 1; + mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; + std::vector column_shape; + MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type, + &col_type_size, &column_shape); + CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key); + if (data == nullptr) data = reinterpret_cast(&data_ptr[0]); + std::shared_ptr tensor; + 
RETURN_IF_NOT_OK(Tensor::CreateEmpty(std::move(TensorShape({2})), std::move(DataType(DataType::DE_INT64)), &tensor)); + auto fea_itr = tensor->begin(); + int64_t offset = 0; + RETURN_IF_NOT_OK(shared_memory->InsertData(data, n_bytes, &offset)); + *fea_itr = offset; + ++fea_itr; + *fea_itr = n_bytes; + *out_tensor = std::move(tensor); + return Status::OK(); +} +#endif + +Status GraphFeatureParser::LoadFeatureIndex(const std::string &key, const std::vector &col_blob, + std::vector *indices) { + const unsigned char *data = nullptr; + std::unique_ptr data_ptr; + uint64_t n_bytes = 0, col_type_size = 1; + mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; + std::vector column_shape; + MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type, + &col_type_size, &column_shape); + CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key); + + if (data == nullptr) data = reinterpret_cast(&data_ptr[0]); + + for (int i = 0; i < n_bytes; i += col_type_size) { + int32_t feature_ind = -1; + if (col_type == mindrecord::ColumnInt32) { + feature_ind = *(reinterpret_cast(data + i)); + } else if (col_type == mindrecord::ColumnInt64) { + feature_ind = *(reinterpret_cast(data + i)); + } else { + RETURN_STATUS_UNEXPECTED("Feature Index needs to be int32/int64 type!"); + } + if (feature_ind >= 0) indices->push_back(feature_ind); + } + return Status::OK(); +} + +} // namespace gnn +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_feature_parser.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_feature_parser.h new file mode 100644 index 00000000000..e84b758b330 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_feature_parser.h @@ -0,0 +1,67 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in 
compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_FEATURE_PARSER_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_FEATURE_PARSER_H_ +#include +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/core/tensor.h" +#if !defined(_WIN32) && !defined(_WIN64) +#include "minddata/dataset/engine/gnn/graph_shared_memory.h" +#endif +#include "minddata/dataset/engine/gnn/feature.h" +#include "minddata/dataset/util/status.h" +#include "minddata/mindrecord/include/shard_column.h" + +namespace mindspore { +namespace dataset { +namespace gnn { + +using mindrecord::ShardColumn; + +class GraphFeatureParser { + public: + explicit GraphFeatureParser(const ShardColumn &shard_column); + + ~GraphFeatureParser() = default; + + // @param std::string key - column name + // @param std::vector &blob - contains data in blob field in mindrecord + // @param std::vector *ind - return value, list of feature index in int32_t + // @return Status - the status code + Status LoadFeatureIndex(const std::string &key, const std::vector &blob, std::vector *ind); + + // @param std::string &key - column name + // @param std::vector &blob - contains data in blob field in mindrecord + // @param std::shared_ptr *tensor - return value feature tensor + // @return Status - the status code + Status LoadFeatureTensor(const std::string &key, const std::vector &blob, std::shared_ptr *tensor); +#if !defined(_WIN32) && !defined(_WIN64) + Status LoadFeatureToSharedMemory(const 
std::string &key, const std::vector &col_blob, + GraphSharedMemory *shared_memory, std::shared_ptr *out_tensor); +#endif + private: + std::unique_ptr shard_column_; +}; +} // namespace gnn +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_FEATURE_PARSER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc index 2339b02de21..1d043ced745 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.cc @@ -13,41 +13,42 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include "minddata/dataset/engine/gnn/graph_loader.h" #include #include #include -#include "minddata/dataset/engine/gnn/graph_loader.h" -#include "mindspore/ccsrc/minddata/mindrecord/include/shard_error.h" +#include "minddata/dataset/engine/gnn/graph_data_impl.h" #include "minddata/dataset/engine/gnn/local_edge.h" #include "minddata/dataset/engine/gnn/local_node.h" #include "minddata/dataset/util/task_manager.h" +#include "minddata/mindrecord/include/shard_error.h" using ShardTuple = std::vector, mindspore::mindrecord::json>>; - namespace mindspore { namespace dataset { namespace gnn { using mindrecord::MSRStatus; -GraphLoader::GraphLoader(std::string mr_filepath, int32_t num_workers) - : mr_path_(mr_filepath), +GraphLoader::GraphLoader(GraphDataImpl *graph_impl, std::string mr_filepath, int32_t num_workers, bool server_mode) + : graph_impl_(graph_impl), + mr_path_(mr_filepath), num_workers_(num_workers), row_id_(0), shard_reader_(nullptr), + graph_feature_parser_(nullptr), keys_({"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}) {} -Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, NodeTypeMap *n_type_map, - EdgeTypeMap *e_type_map, NodeFeatureMap 
*n_feature_map, - EdgeFeatureMap *e_feature_map, DefaultNodeFeatureMap *default_node_feature_map, - DefaultEdgeFeatureMap *default_edge_feature_map) { +Status GraphLoader::GetNodesAndEdges() { + NodeIdMap *n_id_map = &graph_impl_->node_id_map_; + EdgeIdMap *e_id_map = &graph_impl_->edge_id_map_; for (std::deque> &dq : n_deques_) { while (dq.empty() == false) { std::shared_ptr node_ptr = dq.front(); n_id_map->insert({node_ptr->id(), node_ptr}); - (*n_type_map)[node_ptr->type()].push_back(node_ptr->id()); + graph_impl_->node_type_map_[node_ptr->type()].push_back(node_ptr->id()); dq.pop_front(); } } @@ -63,15 +64,15 @@ Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, N RETURN_IF_NOT_OK(edge_ptr->SetNode({src_itr->second, dst_itr->second})); RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second)); e_id_map->insert({edge_ptr->id(), edge_ptr}); // add edge to edge_id_map_ - (*e_type_map)[edge_ptr->type()].push_back(edge_ptr->id()); + graph_impl_->edge_type_map_[edge_ptr->type()].push_back(edge_ptr->id()); dq.pop_front(); } } - for (auto &itr : *n_type_map) itr.second.shrink_to_fit(); - for (auto &itr : *e_type_map) itr.second.shrink_to_fit(); + for (auto &itr : graph_impl_->node_type_map_) itr.second.shrink_to_fit(); + for (auto &itr : graph_impl_->edge_type_map_) itr.second.shrink_to_fit(); - MergeFeatureMaps(n_feature_map, e_feature_map, default_node_feature_map, default_edge_feature_map); + MergeFeatureMaps(); return Status::OK(); } @@ -92,13 +93,26 @@ Status GraphLoader::InitAndLoad() { CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetShardHeader()->GetSchemaCount() > 0, "No schema found!"); CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Launch(true) == MSRStatus::SUCCESS, "fail to launch mr"); - mindrecord::json schema = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema())["schema"]; + graph_impl_->data_schema_ = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema()); + mindrecord::json schema = 
graph_impl_->data_schema_["schema"]; for (const std::string &key : keys_) { if (schema.find(key) == schema.end()) { RETURN_STATUS_UNEXPECTED(key + ":doesn't exist in schema:" + schema.dump()); } } + if (graph_impl_->server_mode_) { +#if !defined(_WIN32) && !defined(_WIN64) + int64_t total_blob_size = 0; + CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetTotalBlobSize(&total_blob_size) == MSRStatus::SUCCESS, + "failed to get total blob size"); + graph_impl_->graph_shared_memory_ = std::make_unique(total_blob_size, mr_path_); + RETURN_IF_NOT_OK(graph_impl_->graph_shared_memory_->CreateSharedMemory()); +#endif + } + + graph_feature_parser_ = std::make_unique(*shard_reader_->GetShardColumn()); + // launching worker threads for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) { RETURN_IF_NOT_OK(vg.CreateAsyncTask("GraphLoader", std::bind(&GraphLoader::WorkerEntry, this, wkr_id))); @@ -116,18 +130,39 @@ Status GraphLoader::LoadNode(const std::vector &col_blob, const mindrec NodeType node_type = static_cast(col_jsn["type"]); (*node) = std::make_shared(node_id, node_type); std::vector indices; - RETURN_IF_NOT_OK(LoadFeatureIndex("node_feature_index", col_blob, col_jsn, &indices)); - - for (int32_t ind : indices) { - std::shared_ptr tensor; - RETURN_IF_NOT_OK(LoadFeatureTensor("node_feature_" + std::to_string(ind), col_blob, col_jsn, &tensor)); - RETURN_IF_NOT_OK((*node)->UpdateFeature(std::make_shared(ind, tensor))); - (*feature_map)[node_type].insert(ind); - if ((*default_feature)[ind] == nullptr) { - std::shared_ptr zero_tensor; - RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor)); - RETURN_IF_NOT_OK(zero_tensor->Zero()); - (*default_feature)[ind] = std::make_shared(ind, zero_tensor); + RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureIndex("node_feature_index", col_blob, &indices)); + if (graph_impl_->server_mode_) { +#if !defined(_WIN32) && !defined(_WIN64) + for (int32_t ind : indices) { + std::shared_ptr tensor_sm; + 
RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureToSharedMemory( + "node_feature_" + std::to_string(ind), col_blob, graph_impl_->graph_shared_memory_.get(), &tensor_sm)); + RETURN_IF_NOT_OK((*node)->UpdateFeature(std::make_shared(ind, tensor_sm, true))); + (*feature_map)[node_type].insert(ind); + if ((*default_feature)[ind] == nullptr) { + std::shared_ptr tensor; + RETURN_IF_NOT_OK( + graph_feature_parser_->LoadFeatureTensor("node_feature_" + std::to_string(ind), col_blob, &tensor)); + std::shared_ptr zero_tensor; + RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor)); + RETURN_IF_NOT_OK(zero_tensor->Zero()); + (*default_feature)[ind] = std::make_shared(ind, zero_tensor); + } + } +#endif + } else { + for (int32_t ind : indices) { + std::shared_ptr tensor; + RETURN_IF_NOT_OK( + graph_feature_parser_->LoadFeatureTensor("node_feature_" + std::to_string(ind), col_blob, &tensor)); + RETURN_IF_NOT_OK((*node)->UpdateFeature(std::make_shared(ind, tensor))); + (*feature_map)[node_type].insert(ind); + if ((*default_feature)[ind] == nullptr) { + std::shared_ptr zero_tensor; + RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor)); + RETURN_IF_NOT_OK(zero_tensor->Zero()); + (*default_feature)[ind] = std::make_shared(ind, zero_tensor); + } } } return Status::OK(); @@ -143,63 +178,42 @@ Status GraphLoader::LoadEdge(const std::vector &col_blob, const mindrec std::shared_ptr dst = std::make_shared(dst_id, -1); (*edge) = std::make_shared(edge_id, edge_type, src, dst); std::vector indices; - RETURN_IF_NOT_OK(LoadFeatureIndex("edge_feature_index", col_blob, col_jsn, &indices)); - for (int32_t ind : indices) { - std::shared_ptr tensor; - RETURN_IF_NOT_OK(LoadFeatureTensor("edge_feature_" + std::to_string(ind), col_blob, col_jsn, &tensor)); - RETURN_IF_NOT_OK((*edge)->UpdateFeature(std::make_shared(ind, tensor))); - (*feature_map)[edge_type].insert(ind); - if ((*default_feature)[ind] == nullptr) { - std::shared_ptr 
zero_tensor; - RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor)); - RETURN_IF_NOT_OK(zero_tensor->Zero()); - (*default_feature)[ind] = std::make_shared(ind, zero_tensor); + RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureIndex("edge_feature_index", col_blob, &indices)); + if (graph_impl_->server_mode_) { +#if !defined(_WIN32) && !defined(_WIN64) + for (int32_t ind : indices) { + std::shared_ptr tensor_sm; + RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureToSharedMemory( + "edge_feature_" + std::to_string(ind), col_blob, graph_impl_->graph_shared_memory_.get(), &tensor_sm)); + RETURN_IF_NOT_OK((*edge)->UpdateFeature(std::make_shared(ind, tensor_sm, true))); + (*feature_map)[edge_type].insert(ind); + if ((*default_feature)[ind] == nullptr) { + std::shared_ptr tensor; + RETURN_IF_NOT_OK( + graph_feature_parser_->LoadFeatureTensor("edge_feature_" + std::to_string(ind), col_blob, &tensor)); + std::shared_ptr zero_tensor; + RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor)); + RETURN_IF_NOT_OK(zero_tensor->Zero()); + (*default_feature)[ind] = std::make_shared(ind, zero_tensor); + } + } +#endif + } else { + for (int32_t ind : indices) { + std::shared_ptr tensor; + RETURN_IF_NOT_OK( + graph_feature_parser_->LoadFeatureTensor("edge_feature_" + std::to_string(ind), col_blob, &tensor)); + RETURN_IF_NOT_OK((*edge)->UpdateFeature(std::make_shared(ind, tensor))); + (*feature_map)[edge_type].insert(ind); + if ((*default_feature)[ind] == nullptr) { + std::shared_ptr zero_tensor; + RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor)); + RETURN_IF_NOT_OK(zero_tensor->Zero()); + (*default_feature)[ind] = std::make_shared(ind, zero_tensor); + } } } - return Status::OK(); -} -Status GraphLoader::LoadFeatureTensor(const std::string &key, const std::vector &col_blob, - const mindrecord::json &col_jsn, std::shared_ptr *tensor) { - const unsigned char *data = nullptr; - 
std::unique_ptr data_ptr; - uint64_t n_bytes = 0, col_type_size = 1; - mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; - std::vector column_shape; - MSRStatus rs = shard_reader_->GetShardColumn()->GetColumnValueByName( - key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, &col_type, &col_type_size, &column_shape); - CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key); - if (data == nullptr) data = reinterpret_cast(&data_ptr[0]); - RETURN_IF_NOT_OK(Tensor::CreateFromMemory(std::move(TensorShape({static_cast(n_bytes / col_type_size)})), - std::move(DataType(mindrecord::ColumnDataTypeNameNormalized[col_type])), - data, tensor)); - return Status::OK(); -} - -Status GraphLoader::LoadFeatureIndex(const std::string &key, const std::vector &col_blob, - const mindrecord::json &col_jsn, std::vector *indices) { - const unsigned char *data = nullptr; - std::unique_ptr data_ptr; - uint64_t n_bytes = 0, col_type_size = 1; - mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType; - std::vector column_shape; - MSRStatus rs = shard_reader_->GetShardColumn()->GetColumnValueByName( - key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, &col_type, &col_type_size, &column_shape); - CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key); - - if (data == nullptr) data = reinterpret_cast(&data_ptr[0]); - - for (int i = 0; i < n_bytes; i += col_type_size) { - int32_t feature_ind = -1; - if (col_type == mindrecord::ColumnInt32) { - feature_ind = *(reinterpret_cast(data + i)); - } else if (col_type == mindrecord::ColumnInt64) { - feature_ind = *(reinterpret_cast(data + i)); - } else { - RETURN_STATUS_UNEXPECTED("Feature Index needs to be int32/int64 type!"); - } - if (feature_ind >= 0) indices->push_back(feature_ind); - } return Status::OK(); } @@ -234,21 +248,19 @@ Status GraphLoader::WorkerEntry(int32_t worker_id) { return Status::OK(); } -void GraphLoader::MergeFeatureMaps(NodeFeatureMap 
*n_feature_map, EdgeFeatureMap *e_feature_map, - DefaultNodeFeatureMap *default_node_feature_map, - DefaultEdgeFeatureMap *default_edge_feature_map) { +void GraphLoader::MergeFeatureMaps() { for (int wkr_id = 0; wkr_id < num_workers_; wkr_id++) { for (auto &m : n_feature_maps_[wkr_id]) { - for (auto &n : m.second) (*n_feature_map)[m.first].insert(n); + for (auto &n : m.second) graph_impl_->node_feature_map_[m.first].insert(n); } for (auto &m : e_feature_maps_[wkr_id]) { - for (auto &n : m.second) (*e_feature_map)[m.first].insert(n); + for (auto &n : m.second) graph_impl_->edge_feature_map_[m.first].insert(n); } for (auto &m : default_node_feature_maps_[wkr_id]) { - (*default_node_feature_map)[m.first] = m.second; + graph_impl_->default_node_feature_map_[m.first] = m.second; } for (auto &m : default_edge_feature_maps_[wkr_id]) { - (*default_edge_feature_map)[m.first] = m.second; + graph_impl_->default_edge_feature_map_[m.first] = m.second; } } n_feature_maps_.clear(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.h index e59b13837cd..58320861f2d 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_loader.h @@ -26,10 +26,13 @@ #include "minddata/dataset/core/data_type.h" #include "minddata/dataset/core/tensor.h" -#include "minddata/dataset/engine/gnn/feature.h" -#include "minddata/dataset/engine/gnn/graph.h" -#include "minddata/dataset/engine/gnn/node.h" #include "minddata/dataset/engine/gnn/edge.h" +#include "minddata/dataset/engine/gnn/feature.h" +#include "minddata/dataset/engine/gnn/graph_feature_parser.h" +#if !defined(_WIN32) && !defined(_WIN64) +#include "minddata/dataset/engine/gnn/graph_shared_memory.h" +#endif +#include "minddata/dataset/engine/gnn/node.h" #include "minddata/dataset/util/status.h" #include "minddata/mindrecord/include/shard_reader.h" namespace mindspore { @@ -46,13 +49,15 @@ using 
EdgeFeatureMap = std::unordered_map>; using DefaultEdgeFeatureMap = std::unordered_map>; +class GraphDataImpl; + // this class interfaces with the underlying storage format (mindrecord) // it returns raw nodes and edges via GetNodesAndEdges // it is then the responsibility of graph to construct itself based on the nodes and edges // if needed, this class could become a base where each derived class handles a specific storage format class GraphLoader { public: - explicit GraphLoader(std::string mr_filepath, int32_t num_workers = 4); + GraphLoader(GraphDataImpl *graph_impl, std::string mr_filepath, int32_t num_workers = 4, bool server_mode = false); ~GraphLoader() = default; // Init mindrecord and load everything into memory multi-threaded @@ -63,8 +68,7 @@ class GraphLoader { // nodes and edges are added to map without any connection. That's because there nodes and edges are read in // random order. src_node and dst_node in Edge are node_id only with -1 as type. // features attached to each node and edge are expected to be filled correctly - Status GetNodesAndEdges(NodeIdMap *, EdgeIdMap *, NodeTypeMap *, EdgeTypeMap *, NodeFeatureMap *, EdgeFeatureMap *, - DefaultNodeFeatureMap *, DefaultEdgeFeatureMap *); + Status GetNodesAndEdges(); private: // @@ -92,29 +96,15 @@ class GraphLoader { Status LoadEdge(const std::vector &blob, const mindrecord::json &jsn, std::shared_ptr *edge, EdgeFeatureMap *feature_map, DefaultEdgeFeatureMap *default_feature); - // @param std::string key - column name - // @param std::vector &blob - contains data in blob field in mindrecord - // @param mindrecord::json &jsn - contains raw data - // @param std::vector *ind - return value, list of feature index in int32_t - // @return Status - the status code - Status LoadFeatureIndex(const std::string &key, const std::vector &blob, const mindrecord::json &jsn, - std::vector *ind); - - // @param std::string &key - column name - // @param std::vector &blob - contains data in blob field in mindrecord 
- // @param mindrecord::json &jsn - contains raw data - // @param std::shared_ptr *tensor - return value feature tensor - // @return Status - the status code - Status LoadFeatureTensor(const std::string &key, const std::vector &blob, const mindrecord::json &jsn, - std::shared_ptr *tensor); - // merge NodeFeatureMap and EdgeFeatureMap of each worker into 1 - void MergeFeatureMaps(NodeFeatureMap *, EdgeFeatureMap *, DefaultNodeFeatureMap *, DefaultEdgeFeatureMap *); + void MergeFeatureMaps(); + GraphDataImpl *graph_impl_; + std::string mr_path_; const int32_t num_workers_; std::atomic_int row_id_; - std::string mr_path_; std::unique_ptr shard_reader_; + std::unique_ptr graph_feature_parser_; std::vector>> n_deques_; std::vector>> e_deques_; std::vector n_feature_maps_; diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_shared_memory.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_shared_memory.cc new file mode 100644 index 00000000000..54e6eda0d28 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_shared_memory.cc @@ -0,0 +1,134 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "minddata/dataset/engine/gnn/graph_shared_memory.h" + +#include + +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +namespace gnn { + +GraphSharedMemory::GraphSharedMemory(int64_t memory_size, key_t memory_key) + : memory_size_(memory_size), + memory_key_(memory_key), + memory_ptr_(nullptr), + memory_offset_(0), + is_new_create_(false) { + std::stringstream stream; + stream << std::hex << memory_key_; + memory_key_str_ = stream.str(); +} + +GraphSharedMemory::GraphSharedMemory(int64_t memory_size, const std::string &mr_file) + : mr_file_(mr_file), + memory_size_(memory_size), + memory_key_(-1), + memory_ptr_(nullptr), + memory_offset_(0), + is_new_create_(false) {} + +GraphSharedMemory::~GraphSharedMemory() { + if (is_new_create_) { + (void)DeleteSharedMemory(); + } +} + +Status GraphSharedMemory::CreateSharedMemory() { + if (memory_key_ == -1) { + // ftok to generate unique key + memory_key_ = ftok(mr_file_.data(), kGnnSharedMemoryId); + CHECK_FAIL_RETURN_UNEXPECTED(memory_key_ != -1, "Failed to get key of shared memory. 
file_name:" + mr_file_); + std::stringstream stream; + stream << std::hex << memory_key_; + memory_key_str_ = stream.str(); + } + int shmflg = (0666 | IPC_CREAT | IPC_EXCL); + Status s = SharedMemoryImpl(shmflg); + if (s.IsOk()) { + is_new_create_ = true; + MS_LOG(INFO) << "Create shared memory success, key=0x" << memory_key_str_; + } else { + MS_LOG(WARNING) << "Shared memory with the same key may already exist, key=0x" << memory_key_str_; + shmflg = (0666 | IPC_CREAT); + s = SharedMemoryImpl(shmflg); + if (!s.IsOk()) { + RETURN_STATUS_UNEXPECTED("Create shared memory fao;ed, key=0x" + memory_key_str_); + } + } + return Status::OK(); +} + +Status GraphSharedMemory::GetSharedMemory() { + int shmflg = 0; + RETURN_IF_NOT_OK(SharedMemoryImpl(shmflg)); + return Status::OK(); +} + +Status GraphSharedMemory::DeleteSharedMemory() { + int shmid = shmget(memory_key_, 0, 0); + CHECK_FAIL_RETURN_UNEXPECTED(shmid != -1, "Failed to get shared memory. key=0x" + memory_key_str_); + int result = shmctl(shmid, IPC_RMID, 0); + CHECK_FAIL_RETURN_UNEXPECTED(result != -1, "Failed to delete shared memory. key=0x" + memory_key_str_); + return Status::OK(); +} + +Status GraphSharedMemory::SharedMemoryImpl(const int &shmflg) { + // shmget returns an identifier in shmid + int shmid = shmget(memory_key_, memory_size_, shmflg); + CHECK_FAIL_RETURN_UNEXPECTED(shmid != -1, "Failed to get shared memory. key=0x" + memory_key_str_); + + // shmat to attach to shared memory + auto data = shmat(shmid, reinterpret_cast(0), 0); + CHECK_FAIL_RETURN_UNEXPECTED(data != (char *)(-1), "Failed to address shared memory. 
key=0x" + memory_key_str_); + memory_ptr_ = reinterpret_cast(data); + + return Status::OK(); +} + +Status GraphSharedMemory::InsertData(const uint8_t *data, int64_t len, int64_t *offset) { + CHECK_FAIL_RETURN_UNEXPECTED(data, "Input data is nullptr."); + CHECK_FAIL_RETURN_UNEXPECTED(len > 0, "Input len is invalid."); + + std::lock_guard lck(mutex_); + CHECK_FAIL_RETURN_UNEXPECTED((memory_size_ - memory_offset_ >= len), + "Insufficient shared memory space to insert data."); + if (EOK != memcpy_s(memory_ptr_ + memory_offset_, memory_size_ - memory_offset_, data, len)) { + RETURN_STATUS_UNEXPECTED("Failed to insert data into shared memory."); + } + *offset = memory_offset_; + memory_offset_ += len; + return Status::OK(); +} + +Status GraphSharedMemory::GetData(uint8_t *data, int64_t data_len, int64_t offset, int64_t get_data_len) { + CHECK_FAIL_RETURN_UNEXPECTED(data, "Input data is nullptr."); + CHECK_FAIL_RETURN_UNEXPECTED(get_data_len > 0, "Input get_data_len is invalid."); + CHECK_FAIL_RETURN_UNEXPECTED(data_len >= get_data_len, "Insufficient target address space."); + + CHECK_FAIL_RETURN_UNEXPECTED(memory_size_ >= get_data_len + offset, + "get_data_len is too large, beyond the space of shared memory."); + if (EOK != memcpy_s(data, data_len, memory_ptr_ + offset, get_data_len)) { + RETURN_STATUS_UNEXPECTED("Failed to get data from shared memory."); + } + return Status::OK(); +} + +} // namespace gnn +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_shared_memory.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_shared_memory.h new file mode 100644 index 00000000000..2b94c8c3000 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_shared_memory.h @@ -0,0 +1,72 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_SHARED_MEMORY_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_SHARED_MEMORY_H_ + +#include +#include +#include +#include + +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { +namespace gnn { + +const int kGnnSharedMemoryId = 65; + +class GraphSharedMemory { + public: + explicit GraphSharedMemory(int64_t memory_size, key_t memory_key); + explicit GraphSharedMemory(int64_t memory_size, const std::string &mr_file); + + ~GraphSharedMemory(); + + // @param uint8_t** shared_memory - shared memory address + // @return Status - the status code + Status CreateSharedMemory(); + + // @param uint8_t** shared_memory - shared memory address + // @return Status - the status code + Status GetSharedMemory(); + + Status DeleteSharedMemory(); + + Status InsertData(const uint8_t *data, int64_t len, int64_t *offset); + + Status GetData(uint8_t *data, int64_t data_len, int64_t offset, int64_t get_data_len); + + key_t memory_key() { return memory_key_; } + + int64_t memory_size() { return memory_size_; } + + private: + Status SharedMemoryImpl(const int &shmflg); + + std::string mr_file_; + int64_t memory_size_; + key_t memory_key_; + std::string memory_key_str_; + uint8_t *memory_ptr_; + int64_t memory_offset_; + std::mutex mutex_; + bool is_new_create_; +}; +} // namespace gnn +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_SHARED_MEMORY_H_ diff --git 
a/mindspore/ccsrc/minddata/dataset/engine/gnn/grpc_async_server.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/grpc_async_server.cc new file mode 100644 index 00000000000..89000e973ee --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/grpc_async_server.cc @@ -0,0 +1,82 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +#include "minddata/dataset/engine/gnn/grpc_async_server.h" + +#include + +#include "minddata/dataset/util/task_manager.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { + +GrpcAsyncServer::GrpcAsyncServer(const std::string &host, int32_t port) : host_(host), port_(port) {} + +GrpcAsyncServer::~GrpcAsyncServer() { Stop(); } + +Status GrpcAsyncServer::Run() { + std::string server_address = host_ + ":" + std::to_string(port_); + grpc::ServerBuilder builder; + // Default message size for gRPC is 4MB. Increase it to 2g-1 + builder.SetMaxReceiveMessageSize(std::numeric_limits::max()); + builder.AddChannelArgument(GRPC_ARG_ALLOW_REUSEPORT, 0); + int port_tcpip = 0; + builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_tcpip); + RETURN_IF_NOT_OK(RegisterService(&builder)); + cq_ = builder.AddCompletionQueue(); + server_ = builder.BuildAndStart(); + if (server_) { + MS_LOG(INFO) << "Server listening on " << server_address; + } else { + std::string errMsg = "Fail to start server. 
"; + if (port_tcpip != port_) { + errMsg += "Unable to bind to address " + server_address + "."; + } + RETURN_STATUS_UNEXPECTED(errMsg); + } + return Status::OK(); +} + +Status GrpcAsyncServer::HandleRequest() { + bool success; + void *tag; + // We loop through the grpc queue. Each connection if successful + // will come back with our own tag which is an instance of CallData + // and we simply call its functor. But first we need to create these instances + // and inject them into the grpc queue. + RETURN_IF_NOT_OK(EnqueueRequest()); + while (cq_->Next(&tag, &success)) { + RETURN_IF_INTERRUPTED(); + if (success) { + RETURN_IF_NOT_OK(ProcessRequest(tag)); + } else { + MS_LOG(DEBUG) << "cq_->Next failed."; + } + } + return Status::OK(); +} + +void GrpcAsyncServer::Stop() { + if (server_) { + server_->Shutdown(); + } + // Always shutdown the completion queue after the server. + if (cq_) { + cq_->Shutdown(); + } +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/grpc_async_server.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/grpc_async_server.h new file mode 100644 index 00000000000..8023c23d7cf --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/grpc_async_server.h @@ -0,0 +1,59 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+*/ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRPC_ASYNC_SERVER_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRPC_ASYNC_SERVER_H_ + +#include +#include +#include +#include + +#include "grpcpp/grpcpp.h" +#include "grpcpp/impl/codegen/async_unary_call.h" + +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +/// \brief Async server base class +class GrpcAsyncServer { + public: + explicit GrpcAsyncServer(const std::string &host, int32_t port); + virtual ~GrpcAsyncServer(); + /// \brief Brings up gRPC server + /// \return none + Status Run(); + /// \brief Entry function to handle async server request + Status HandleRequest(); + + void Stop(); + + virtual Status RegisterService(grpc::ServerBuilder *builder) = 0; + + virtual Status EnqueueRequest() = 0; + + virtual Status ProcessRequest(void *tag) = 0; + + protected: + int32_t port_; + std::string host_; + std::unique_ptr cq_; + std::unique_ptr server_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRPC_ASYNC_SERVER_H_ diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.cc index 642c73eed3a..d20be6e318e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.cc @@ -44,6 +44,7 @@ Status LocalEdge::UpdateFeature(const std::shared_ptr &feature) { return Status::OK(); } } + } // namespace gnn } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.h index e9c7ba7f0ec..23dfd03694f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_edge.h @@ -20,10 +20,10 @@ #include #include -#include "minddata/dataset/util/status.h" #include "minddata/dataset/engine/gnn/edge.h" 
#include "minddata/dataset/engine/gnn/feature.h" #include "minddata/dataset/engine/gnn/node.h" +#include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.h index 350797ac75f..7f5674b1867 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/local_node.h @@ -20,9 +20,9 @@ #include #include -#include "minddata/dataset/util/status.h" #include "minddata/dataset/engine/gnn/node.h" #include "minddata/dataset/engine/gnn/feature.h" +#include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/node.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/node.h index c89bb0e9056..e2659891044 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/node.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/node.h @@ -20,8 +20,8 @@ #include #include -#include "minddata/dataset/util/status.h" #include "minddata/dataset/engine/gnn/feature.h" +#include "minddata/dataset/util/status.h" namespace mindspore { namespace dataset { diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/tensor_proto.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/tensor_proto.cc new file mode 100644 index 00000000000..c2dd41bc073 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/tensor_proto.cc @@ -0,0 +1,84 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/engine/gnn/tensor_proto.h" + +#include +#include +#include + +namespace mindspore { +namespace dataset { + +const std::unordered_map g_pb2datatype_map{ + {DataTypePb::DE_PB_UNKNOWN, DataType::DE_UNKNOWN}, {DataTypePb::DE_PB_BOOL, DataType::DE_BOOL}, + {DataTypePb::DE_PB_INT8, DataType::DE_INT8}, {DataTypePb::DE_PB_UINT8, DataType::DE_UINT8}, + {DataTypePb::DE_PB_INT16, DataType::DE_INT16}, {DataTypePb::DE_PB_UINT16, DataType::DE_UINT16}, + {DataTypePb::DE_PB_INT32, DataType::DE_INT32}, {DataTypePb::DE_PB_UINT32, DataType::DE_UINT32}, + {DataTypePb::DE_PB_INT64, DataType::DE_INT64}, {DataTypePb::DE_PB_UINT64, DataType::DE_UINT64}, + {DataTypePb::DE_PB_FLOAT16, DataType::DE_FLOAT16}, {DataTypePb::DE_PB_FLOAT32, DataType::DE_FLOAT32}, + {DataTypePb::DE_PB_FLOAT64, DataType::DE_FLOAT64}, {DataTypePb::DE_PB_STRING, DataType::DE_STRING}, +}; + +const std::unordered_map g_datatype2pb_map{ + {DataType::DE_UNKNOWN, DataTypePb::DE_PB_UNKNOWN}, {DataType::DE_BOOL, DataTypePb::DE_PB_BOOL}, + {DataType::DE_INT8, DataTypePb::DE_PB_INT8}, {DataType::DE_UINT8, DataTypePb::DE_PB_UINT8}, + {DataType::DE_INT16, DataTypePb::DE_PB_INT16}, {DataType::DE_UINT16, DataTypePb::DE_PB_UINT16}, + {DataType::DE_INT32, DataTypePb::DE_PB_INT32}, {DataType::DE_UINT32, DataTypePb::DE_PB_UINT32}, + {DataType::DE_INT64, DataTypePb::DE_PB_INT64}, {DataType::DE_UINT64, DataTypePb::DE_PB_UINT64}, + {DataType::DE_FLOAT16, DataTypePb::DE_PB_FLOAT16}, {DataType::DE_FLOAT32, DataTypePb::DE_PB_FLOAT32}, + {DataType::DE_FLOAT64, DataTypePb::DE_PB_FLOAT64}, 
{DataType::DE_STRING, DataTypePb::DE_PB_STRING}, +}; + +Status TensorToPb(const std::shared_ptr tensor, TensorPb *tensor_pb) { + CHECK_FAIL_RETURN_UNEXPECTED(tensor, "Parameter tensor is a null pointer"); + CHECK_FAIL_RETURN_UNEXPECTED(tensor_pb, "Parameter tensor_pb is a null pointer"); + + std::vector shape = tensor->shape().AsVector(); + for (auto dim : shape) { + tensor_pb->add_dims(static_cast(dim)); + } + auto iter = g_datatype2pb_map.find(tensor->type().value()); + if (iter == g_datatype2pb_map.end()) { + RETURN_STATUS_UNEXPECTED("Invalid tensor type: " + tensor->type().ToString()); + } + tensor_pb->set_tensor_type(iter->second); + tensor_pb->set_data(tensor->GetBuffer(), tensor->SizeInBytes()); + return Status::OK(); +} + +Status PbToTensor(const TensorPb *tensor_pb, std::shared_ptr *tensor) { + CHECK_FAIL_RETURN_UNEXPECTED(tensor_pb, "Parameter tensor_pb is a null pointer"); + CHECK_FAIL_RETURN_UNEXPECTED(tensor, "Parameter tensor is a null pointer"); + + std::vector shape; + shape.resize(tensor_pb->dims().size()); + std::transform(tensor_pb->dims().begin(), tensor_pb->dims().end(), shape.begin(), + [](const google::protobuf::int64 dim) { return static_cast(dim); }); + auto iter = g_pb2datatype_map.find(tensor_pb->tensor_type()); + if (iter == g_pb2datatype_map.end()) { + RETURN_STATUS_UNEXPECTED("Invalid Tensor_pb type: " + std::to_string(tensor_pb->tensor_type())); + } + DataType::Type type = iter->second; + std::shared_ptr tensor_out; + RETURN_IF_NOT_OK(Tensor::CreateFromMemory(TensorShape(shape), DataType(type), + reinterpret_cast(tensor_pb->data().data()), + tensor_pb->data().size(), &tensor_out)); + *tensor = std::move(tensor_out); + return Status::OK(); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/tensor_proto.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/tensor_proto.h new file mode 100644 index 00000000000..4ffc10b0647 --- /dev/null +++ 
b/mindspore/ccsrc/minddata/dataset/engine/gnn/tensor_proto.h @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_TENSOR_PROTO_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_TENSOR_PROTO_H_ + +#include +#include +#include + +#include "proto/gnn_tensor.pb.h" +#include "minddata/dataset/core/tensor.h" + +namespace mindspore { +namespace dataset { + +Status TensorToPb(const std::shared_ptr tensor, TensorPb *tensor_pb); + +Status PbToTensor(const TensorPb *tensor_pb, std::shared_ptr *tensor); + +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_TENSOR_PROTO_H_ diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_column.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_column.h index 9510eeed1c6..f3bd43c5d81 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_column.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_column.h @@ -61,6 +61,7 @@ const std::unordered_map ColumnDataTypeMap = { class ShardColumn { public: explicit ShardColumn(const std::shared_ptr &shard_header, bool compress_integer = true); + explicit ShardColumn(const json &schema_json, bool compress_integer = true); ~ShardColumn() = default; @@ -72,23 +73,29 @@ class ShardColumn { std::vector *column_shape); /// \brief compress blob - std::vector CompressBlob(const std::vector 
&blob); + std::vector CompressBlob(const std::vector &blob, int64_t *compression_size); /// \brief check if blob compressed bool CheckCompressBlob() const { return has_compress_blob_; } + /// \brief getter uint64_t GetNumBlobColumn() const { return num_blob_column_; } + /// \brief getter std::vector GetColumnName() { return column_name_; } + /// \brief getter std::vector GeColumnDataType() { return column_data_type_; } + /// \brief getter std::vector> GetColumnShape() { return column_shape_; } /// \brief get column value from blob MSRStatus GetColumnFromBlob(const std::string &column_name, const std::vector &columns_blob, const unsigned char **data, std::unique_ptr *data_ptr, uint64_t *const n_bytes); + + /// \brief get column type std::pair GetColumnTypeByName(const std::string &column_name, ColumnDataType *column_data_type, uint64_t *column_data_type_size, @@ -99,6 +106,9 @@ class ShardColumn { std::unique_ptr *data_ptr, uint64_t *n_bytes); private: + /// \brief intialization + void Init(const json &schema_json, bool compress_integer = true); + /// \brief get float value from json template MSRStatus GetFloat(std::unique_ptr *data_ptr, const json &json_column_value, bool use_double); diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_header.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_header.h index 51928d7874e..33999ef5bd0 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_header.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_header.h @@ -65,6 +65,11 @@ class ShardHeader { /// \return the Statistic std::vector> GetStatistics(); + /// \brief add the statistic and save it + /// \param[in] statistic info of slim size + /// \return null + int64_t GetSlimSizeStatistic(const json &slim_size_json); + /// \brief get the fields of the index /// \return the fields of the index std::vector> GetFields(); @@ -114,10 +119,14 @@ class ShardHeader { uint64_t GetPageSize() const { return page_size_; } + uint64_t GetCompressionSize() 
const { return compression_size_; } + void SetHeaderSize(const uint64_t &header_size) { header_size_ = header_size; } void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; } + void SetCompressionSize(const uint64_t &compression_size) { compression_size_ = compression_size; } + std::vector SerializeHeader(); MSRStatus PagesToFile(const std::string dump_file_name); @@ -177,6 +186,7 @@ class ShardHeader { uint32_t shard_count_; uint64_t header_size_; uint64_t page_size_; + uint64_t compression_size_; std::shared_ptr index_; std::vector shard_addresses_; diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h index 6f185d5a4e4..13607aebe3d 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h +++ b/mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h @@ -209,6 +209,9 @@ class ShardReader { /// \brief get all classes MSRStatus GetAllClasses(const std::string &category_field, std::set &categories); + /// \brief get the size of blob data + MSRStatus GetTotalBlobSize(int64_t *total_blob_size); + protected: /// \brief sqlite call back function static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names); @@ -323,6 +326,7 @@ class ShardReader { const std::string kThreadName = "THRD_ITER_"; // prefix of thread name std::vector thread_set_; // thread list int num_rows_; // number of rows + int64_t total_blob_size_; // total size of blob data std::mutex mtx_delivery_; // locker for delivery std::condition_variable cv_delivery_; // conditional variable for delivery std::condition_variable cv_iterator_; // conditional variable for iterator diff --git a/mindspore/ccsrc/minddata/mindrecord/include/shard_writer.h b/mindspore/ccsrc/minddata/mindrecord/include/shard_writer.h index ddb7e7cb8fb..49b6712a668 100644 --- a/mindspore/ccsrc/minddata/mindrecord/include/shard_writer.h +++ 
b/mindspore/ccsrc/minddata/mindrecord/include/shard_writer.h @@ -257,6 +257,7 @@ class ShardWriter { std::mutex check_mutex_; // mutex for data check std::atomic flag_{false}; + std::atomic compression_size_; }; } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc index 3aa40434f06..67fee2c13c9 100644 --- a/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc +++ b/mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc @@ -43,6 +43,7 @@ ShardReader::ShardReader() { page_size_ = 0; header_size_ = 0; num_rows_ = 0; + total_blob_size_ = 0; num_padded_ = 0; } @@ -55,9 +56,11 @@ std::pair> ShardReader::GetMeta(const std::s return {FAILED, {}}; } auto header = ret.second; - meta_data = {{"header_size", header["header_size"]}, {"page_size", header["page_size"]}, - {"version", header["version"]}, {"index_fields", header["index_fields"]}, - {"schema", header["schema"]}, {"blob_fields", header["blob_fields"]}}; + uint64_t compression_size = header.contains("compression_size") ? 
header["compression_size"].get() : 0; + meta_data = {{"header_size", header["header_size"]}, {"page_size", header["page_size"]}, + {"compression_size", compression_size}, {"version", header["version"]}, + {"index_fields", header["index_fields"]}, {"schema", header["schema"]}, + {"blob_fields", header["blob_fields"]}}; return {SUCCESS, header["shard_addresses"]}; } @@ -145,6 +148,11 @@ MSRStatus ShardReader::Init(const std::vector &file_paths, bool loa for (const auto &rg : row_group_summary) { num_rows_ += std::get<3>(rg); } + auto disk_size = page_size_ * row_group_summary.size(); + auto compression_size = shard_header_->GetCompressionSize(); + total_blob_size_ = disk_size + compression_size; + MS_LOG(INFO) << "Blob data size, on disk: " << disk_size << " , additional uncompression: " << compression_size + << " , Total: " << total_blob_size_; MS_LOG(INFO) << "Get meta from mindrecord file & index file successfully."; @@ -272,6 +280,11 @@ std::vector> ShardReader::ReadRowGroupSummar return row_group_summary; } +MSRStatus ShardReader::GetTotalBlobSize(int64_t *total_blob_size) { + *total_blob_size = total_blob_size_; + return SUCCESS; +} + MSRStatus ShardReader::ConvertLabelToJson(const std::vector> &labels, std::shared_ptr fs, std::vector>> &offsets, int shard_id, diff --git a/mindspore/ccsrc/minddata/mindrecord/io/shard_writer.cc b/mindspore/ccsrc/minddata/mindrecord/io/shard_writer.cc index bf702180abd..889a461b079 100644 --- a/mindspore/ccsrc/minddata/mindrecord/io/shard_writer.cc +++ b/mindspore/ccsrc/minddata/mindrecord/io/shard_writer.cc @@ -28,11 +28,9 @@ using mindspore::MsLogLevel::INFO; namespace mindspore { namespace mindrecord { ShardWriter::ShardWriter() - : shard_count_(1), - header_size_(kDefaultHeaderSize), - page_size_(kDefaultPageSize), - row_count_(0), - schema_count_(1) {} + : shard_count_(1), header_size_(kDefaultHeaderSize), page_size_(kDefaultPageSize), row_count_(0), schema_count_(1) { + compression_size_ = 0; +} ShardWriter::~ShardWriter() { 
for (int i = static_cast(file_streams_.size()) - 1; i >= 0; i--) { @@ -201,6 +199,7 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) { if (ret == FAILED) { return FAILED; } + compression_size_ = shard_header_->GetCompressionSize(); ret = Open(real_addresses, true); if (ret == FAILED) { MS_LOG(ERROR) << "Open file failed"; @@ -614,7 +613,9 @@ MSRStatus ShardWriter::WriteRawDataPreCheck(std::map // compress blob if (shard_column_->CheckCompressBlob()) { for (auto &blob : blob_data) { - blob = shard_column_->CompressBlob(blob); + int64_t compression_bytes = 0; + blob = shard_column_->CompressBlob(blob, &compression_bytes); + compression_size_ += compression_bytes; } } @@ -1177,6 +1178,11 @@ MSRStatus ShardWriter::WriteShardHeader() { MS_LOG(ERROR) << "Shard header is null"; return FAILED; } + + int64_t compression_temp = compression_size_; + uint64_t compression_size = compression_temp > 0 ? compression_temp : 0; + shard_header_->SetCompressionSize(compression_size); + auto shard_header = shard_header_->SerializeHeader(); // Write header data to multi files if (shard_count_ > static_cast(file_streams_.size()) || shard_count_ > static_cast(shard_header.size())) { diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_column.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_column.cc index 47e001e8f88..eb4229be9cc 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_column.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_column.cc @@ -24,7 +24,15 @@ namespace mindspore { namespace mindrecord { ShardColumn::ShardColumn(const std::shared_ptr &shard_header, bool compress_integer) { auto first_schema = shard_header->GetSchemas()[0]; - auto schema = first_schema->GetSchema()["schema"]; + json schema_json = first_schema->GetSchema(); + Init(schema_json, compress_integer); +} + +ShardColumn::ShardColumn(const json &schema_json, bool compress_integer) { Init(schema_json, compress_integer); } + +void ShardColumn::Init(const json &schema_json, 
bool compress_integer) { + auto schema = schema_json["schema"]; + auto blob_fields = schema_json["blob_fields"]; bool has_integer_array = false; for (json::iterator it = schema.begin(); it != schema.end(); ++it) { @@ -52,8 +60,6 @@ ShardColumn::ShardColumn(const std::shared_ptr &shard_header, bool column_name_id_[column_name_[i]] = i; } - auto blob_fields = first_schema->GetBlobFields(); - for (const auto &field : blob_fields) { blob_column_.push_back(field); } @@ -282,8 +288,9 @@ ColumnCategory ShardColumn::CheckColumnName(const std::string &column_name) { return it_blob == blob_column_id_.end() ? ColumnInRaw : ColumnInBlob; } -std::vector ShardColumn::CompressBlob(const std::vector &blob) { +std::vector ShardColumn::CompressBlob(const std::vector &blob, int64_t *compression_size) { // Skip if no compress columns + *compression_size = 0; if (!CheckCompressBlob()) return blob; std::vector dst_blob; @@ -295,7 +302,9 @@ std::vector ShardColumn::CompressBlob(const std::vector &blob) // Compress and return is blob has 1 column only if (num_blob_column_ == 1) { - return CompressInt(blob, int_type); + dst_blob = CompressInt(blob, int_type); + *compression_size = static_cast(blob.size()) - static_cast(dst_blob.size()); + return dst_blob; } // Just copy and continue if column dat type is not int32/int64 @@ -319,6 +328,7 @@ std::vector ShardColumn::CompressBlob(const std::vector &blob) i_src += kInt64Len + num_bytes; } MS_LOG(DEBUG) << "Compress all blob from " << blob.size() << " to " << dst_blob.size() << "."; + *compression_size = static_cast(blob.size()) - static_cast(dst_blob.size()); return dst_blob; } diff --git a/mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc b/mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc index 9f75d84e7ac..a90753ef69c 100644 --- a/mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc +++ b/mindspore/ccsrc/minddata/mindrecord/meta/shard_header.cc @@ -33,7 +33,9 @@ using mindspore::MsLogLevel::ERROR; namespace mindspore { 
namespace mindrecord { std::atomic thread_status(false); -ShardHeader::ShardHeader() : shard_count_(0), header_size_(0), page_size_(0) { index_ = std::make_shared(); } +ShardHeader::ShardHeader() : shard_count_(0), header_size_(0), page_size_(0), compression_size_(0) { + index_ = std::make_shared(); +} MSRStatus ShardHeader::InitializeHeader(const std::vector &headers, bool load_dataset) { shard_count_ = headers.size(); @@ -54,6 +56,7 @@ MSRStatus ShardHeader::InitializeHeader(const std::vector &headers, bool l ParseShardAddress(header["shard_addresses"]); header_size_ = header["header_size"].get(); page_size_ = header["page_size"].get(); + compression_size_ = header.contains("compression_size") ? header["compression_size"].get() : 0; } if (SUCCESS != ParsePage(header["page"], shard_index, load_dataset)) { return FAILED; @@ -146,9 +149,12 @@ std::pair ShardHeader::BuildSingleHeader(const std::string &fil return {FAILED, json()}; } json raw_header = ret.second; + uint64_t compression_size = + raw_header.contains("compression_size") ? 
raw_header["compression_size"].get() : 0; json header = {{"shard_addresses", raw_header["shard_addresses"]}, {"header_size", raw_header["header_size"]}, {"page_size", raw_header["page_size"]}, + {"compression_size", compression_size}, {"index_fields", raw_header["index_fields"]}, {"blob_fields", raw_header["schema"][0]["blob_fields"]}, {"schema", raw_header["schema"][0]["schema"]}, @@ -343,6 +349,7 @@ std::vector ShardHeader::SerializeHeader() { s += "\"index_fields\":" + index + ","; s += "\"page\":" + pages[shardId] + ","; s += "\"page_size\":" + std::to_string(page_size_) + ","; + s += "\"compression_size\":" + std::to_string(compression_size_) + ","; s += "\"schema\":" + schema + ","; s += "\"shard_addresses\":" + address + ","; s += "\"shard_id\":" + std::to_string(shardId) + ","; diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index e035146225a..99fba20a507 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -3083,20 +3083,22 @@ def _cpp_sampler_fn(sampler, dataset): yield tuple([np.array(x, copy=False) for x in val]) -def _cpp_sampler_fn_mp(sampler, dataset, num_worker): +def _cpp_sampler_fn_mp(sampler, dataset, num_worker, multi_process): """ Multiprocessing generator function wrapper for mappable dataset with cpp sampler. """ indices = sampler.get_indices() - return _sampler_fn_mp(indices, dataset, num_worker) + sample_fn = SamplerFn(dataset, num_worker, multi_process) + return sample_fn.process(indices) -def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker): +def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker, multi_process): """ Multiprocessing generator function wrapper for mappable dataset with python sampler. 
""" indices = _fetch_py_sampler_indices(sampler, num_samples) - return _sampler_fn_mp(indices, dataset, num_worker) + sample_fn = SamplerFn(dataset, num_worker, multi_process) + return sample_fn.process(indices) def _fetch_py_sampler_indices(sampler, num_samples): @@ -3130,63 +3132,92 @@ def _fill_worker_indices(workers, indices, idx): return idx -def _sampler_fn_mp(indices, dataset, num_worker): +class SamplerFn: """ - Multiprocessing generator function wrapper master process. + Multiprocessing or multithread generator function wrapper master process. """ - workers = [] - # Event for end of epoch - eoe = multiprocessing.Event() + def __init__(self, dataset, num_worker, multi_process): + self.workers = [] + self.num_worker = num_worker + self.multi_process = multi_process + # Event for end of epoch + if multi_process is True: + self.eoe = multiprocessing.Event() + self.eof = multiprocessing.Event() + else: + self.eoe = threading.Event() + self.eof = threading.Event() + # Create workers + for _ in range(num_worker): + if multi_process is True: + worker = _GeneratorWorkerMp(dataset, self.eoe, self.eof) + else: + worker = _GeneratorWorkerMt(dataset, self.eoe, self.eof) + worker.daemon = True + self.workers.append(worker) - # Create workers - for _ in range(num_worker): - worker = _GeneratorWorker(dataset, eoe) - worker.daemon = True - workers.append(worker) + def process(self, indices): + """ + The main process, start the child process or child thread, and fill the index queue, + get the result from the result and return. 
+ """ + # Fill initial index queues + idx_cursor = 0 + idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor) - # Fill initial index queues - idx_cursor = 0 - idx_cursor = _fill_worker_indices(workers, indices, idx_cursor) + # Start all workers + for w in self.workers: + w.start() - # Start all workers - for w in workers: - w.start() + # Fetch results + for i in range(len(indices)): + # Fetch result and put index + try: + result = self.workers[i % self.num_worker].get() + except queue.Empty: + raise Exception("Generator worker process timeout") + except KeyboardInterrupt: + self.eof.set() + for w in self.workers: + w.terminate() + w.join() + raise Exception("Generator worker receives KeyboardInterrupt") + if idx_cursor < len(indices): + idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor) + # Set eoe event once all indices are sent + if idx_cursor == len(indices) and not self.eoe.is_set(): + self.eoe.set() + yield tuple([np.array(x, copy=False) for x in result]) - # Fetch results - for i in range(len(indices)): - # Fetch result and put index - try: - result = workers[i % num_worker].get() - except queue.Empty: - raise Exception("Generator worker process timeout") - except KeyboardInterrupt: - for w in workers: - w.terminate() + def __del__(self): + self.eoe.set() + self.eof.set() + if self.multi_process is False: + for w in self.workers: w.join() - raise Exception("Generator worker receives KeyboardInterrupt") - if idx_cursor < len(indices): - idx_cursor = _fill_worker_indices(workers, indices, idx_cursor) - # Set eoe event once all indices are sent - if idx_cursor == len(indices) and not eoe.is_set(): - eoe.set() - yield tuple([np.array(x, copy=False) for x in result]) -def _generator_worker_loop(dataset, idx_queue, result_queue, eoe): +def _generator_worker_loop(dataset, idx_queue, result_queue, eoe, eof): """ - Multiprocessing generator worker process loop. + Multiprocessing or multithread generator worker process loop. 
""" while True: # Fetch index, block try: - idx = idx_queue.get() + idx = idx_queue.get(timeout=10) except KeyboardInterrupt: raise Exception("Generator worker receives KeyboardInterrupt") + except queue.Empty: + if eof.is_set() or eoe.is_set(): + raise Exception("Generator worker receives queue.Empty") + continue if idx is None: # When the queue is out of scope from master process, a None item can be fetched from the queue. # Upon receiving None, worker process should check if EOE is set. assert eoe.is_set(), "" return + if eof.is_set(): + return # Fetch data, any exception from __getitem__ will terminate worker and timeout master process result = dataset[idx] # Send data, block @@ -3195,17 +3226,19 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eoe): except KeyboardInterrupt: raise Exception("Generator worker receives KeyboardInterrupt") del result, idx + if eoe.is_set() and idx_queue.empty(): + return -class _GeneratorWorker(multiprocessing.Process): +class _GeneratorWorkerMt(threading.Thread): """ - Worker process for multiprocess Generator. + Worker process for multithread Generator. """ - def __init__(self, dataset, eoe): - self.idx_queue = multiprocessing.Queue(16) - self.res_queue = multiprocessing.Queue(16) - super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe)) + def __init__(self, dataset, eoe, eof): + self.idx_queue = queue.Queue(16) + self.res_queue = queue.Queue(16) + super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe, eof)) def put(self, item): """ @@ -3217,7 +3250,30 @@ class _GeneratorWorker(multiprocessing.Process): """ Get function for worker result queue. Block with timeout. """ - return self.res_queue.get() + return self.res_queue.get(timeout=10) + + +class _GeneratorWorkerMp(multiprocessing.Process): + """ + Worker process for multiprocess Generator. 
+ """ + + def __init__(self, dataset, eoe, eof): + self.idx_queue = multiprocessing.Queue(16) + self.res_queue = multiprocessing.Queue(16) + super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe, eof)) + + def put(self, item): + """ + Put function for worker index queue. Never block. Raise queue.Full on failure. + """ + self.idx_queue.put_nowait(item) + + def get(self): + """ + Get function for worker result queue. Block with timeout. + """ + return self.res_queue.get(timeout=10) def __del__(self): self.terminate() @@ -3280,6 +3336,8 @@ class GeneratorDataset(MappableDataset): When this argument is specified, 'num_samples' will not effect. Random accessible input is required. shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only when num_shards is also specified. Random accessible input is required. + python_multiprocessing (bool, optional): Parallelize python operations with multiple worker process. This + option could be beneficial if the python operation is computational heavy (default=True). 
Examples: >>> import mindspore.dataset as ds @@ -3316,12 +3374,14 @@ class GeneratorDataset(MappableDataset): @check_generatordataset def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None, - num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None): + num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None, + python_multiprocessing=True): super().__init__(num_parallel_workers) self.source = source self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) self.num_samples = num_samples self.num_shards = num_shards + self.python_multiprocessing = python_multiprocessing if column_names is not None and not isinstance(column_names, list): column_names = [column_names] @@ -3403,12 +3463,16 @@ class GeneratorDataset(MappableDataset): sampler_instance.set_num_rows(len(self.source)) sampler_instance.initialize() if new_op.num_parallel_workers > 1: - new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, self.source, new_op.num_parallel_workers)) + new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, self.source, + new_op.num_parallel_workers, + self.python_multiprocessing)) else: new_op.source = (lambda: _cpp_sampler_fn(sampler_instance, self.source)) else: if new_op.num_parallel_workers > 1: - new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, self.source, new_op.num_parallel_workers)) + new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, self.source, + new_op.num_parallel_workers, + self.python_multiprocessing)) else: new_op.source = (lambda: _py_sampler_fn(new_op.sampler, new_op.num_samples, self.source)) else: diff --git a/mindspore/dataset/engine/graphdata.py b/mindspore/dataset/engine/graphdata.py index 8641761daa3..065d15f7b3a 100644 --- a/mindspore/dataset/engine/graphdata.py +++ b/mindspore/dataset/engine/graphdata.py @@ -16,8 +16,11 @@ graphdata.py supports loading graph 
dataset for GNN network training, and provides operations related to graph data. """ +import atexit +import time import numpy as np -from mindspore._c_dataengine import Graph +from mindspore._c_dataengine import GraphDataClient +from mindspore._c_dataengine import GraphDataServer from mindspore._c_dataengine import Tensor from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \ @@ -34,14 +37,52 @@ class GraphData: dataset_file (str): One of file names in dataset. num_parallel_workers (int, optional): Number of workers to process the Dataset in parallel (default=None). + working_mode (str, optional): Set working mode, now support 'local'/'client'/'server' (default='local'). + + - 'local', used in non-distributed training scenarios. + + - 'client', used in distributed training scenarios, the client does not load data, + but obtains data from the server. + + - 'server', used in distributed training scenarios, the server loads the data + and is available to the client. + + hostname (str, optional): Valid when working_mode is set to 'client' or 'server', + set the hostname of the graph data server (default='127.0.0.1'). + port (int, optional): Valid when working_mode is set to 'client' or 'server', + set the port of the graph data server, the range is 1024-65535 (default=50051). + num_client (int, optional): Valid when working_mode is set to 'server', + set the number of clients expected to connect, and the server will allocate corresponding + resources according to this parameter (default=1). + auto_shutdown (bool, optional): Valid when working_mode is set to 'server', + Control when all clients have connected and no client connected to the server, + automatically exit the server (default=True). 
""" @check_gnn_graphdata - def __init__(self, dataset_file, num_parallel_workers=None): + def __init__(self, dataset_file, num_parallel_workers=None, working_mode='local', hostname='127.0.0.1', port=50051, + num_client=1, auto_shutdown=True): self._dataset_file = dataset_file + self._working_mode = working_mode if num_parallel_workers is None: num_parallel_workers = 1 - self._graph = Graph(dataset_file, num_parallel_workers) + + def stop(): + self._graph_data.stop() + atexit.register(stop) + + if working_mode in ['local', 'client']: + self._graph_data = GraphDataClient(dataset_file, num_parallel_workers, working_mode, hostname, port) + + if working_mode == 'server': + self._graph_data = GraphDataServer( + dataset_file, num_parallel_workers, hostname, port, num_client, auto_shutdown) + try: + while self._graph_data.is_stoped() is not True: + time.sleep(1) + except KeyboardInterrupt: + # self._graph_data.stop() + raise Exception("Graph data server receives KeyboardInterrupt") @check_gnn_get_all_nodes def get_all_nodes(self, node_type): @@ -62,7 +103,9 @@ class GraphData: Raises: TypeError: If `node_type` is not integer. """ - return self._graph.get_all_nodes(node_type).as_array() + if self._working_mode == 'server': + raise Exception("This method is not supported when working mode is server") + return self._graph_data.get_all_nodes(node_type).as_array() @check_gnn_get_all_edges def get_all_edges(self, edge_type): @@ -83,7 +126,9 @@ class GraphData: Raises: TypeError: If `edge_type` is not integer. """ - return self._graph.get_all_edges(edge_type).as_array() + if self._working_mode == 'server': + raise Exception("This method is not supported when working mode is server") + return self._graph_data.get_all_edges(edge_type).as_array() @check_gnn_get_nodes_from_edges def get_nodes_from_edges(self, edge_list): @@ -99,7 +144,9 @@ class GraphData: Raises: TypeError: If `edge_list` is not list or ndarray. 
""" - return self._graph.get_nodes_from_edges(edge_list).as_array() + if self._working_mode == 'server': + raise Exception("This method is not supported when working mode is server") + return self._graph_data.get_nodes_from_edges(edge_list).as_array() @check_gnn_get_all_neighbors def get_all_neighbors(self, node_list, neighbor_type): @@ -123,7 +170,9 @@ class GraphData: TypeError: If `node_list` is not list or ndarray. TypeError: If `neighbor_type` is not integer. """ - return self._graph.get_all_neighbors(node_list, neighbor_type).as_array() + if self._working_mode == 'server': + raise Exception("This method is not supported when working mode is server") + return self._graph_data.get_all_neighbors(node_list, neighbor_type).as_array() @check_gnn_get_sampled_neighbors def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types): @@ -155,7 +204,9 @@ class GraphData: TypeError: If `neighbor_nums` is not list or ndarray. TypeError: If `neighbor_types` is not list or ndarray. """ - return self._graph.get_sampled_neighbors( + if self._working_mode == 'server': + raise Exception("This method is not supported when working mode is server") + return self._graph_data.get_sampled_neighbors( node_list, neighbor_nums, neighbor_types).as_array() @check_gnn_get_neg_sampled_neighbors @@ -182,7 +233,9 @@ class GraphData: TypeError: If `neg_neighbor_num` is not integer. TypeError: If `neg_neighbor_type` is not integer. """ - return self._graph.get_neg_sampled_neighbors( + if self._working_mode == 'server': + raise Exception("This method is not supported when working mode is server") + return self._graph_data.get_neg_sampled_neighbors( node_list, neg_neighbor_num, neg_neighbor_type).as_array() @check_gnn_get_node_feature @@ -207,10 +260,12 @@ class GraphData: TypeError: If `node_list` is not list or ndarray. TypeError: If `feature_types` is not list or ndarray. 
""" + if self._working_mode == 'server': + raise Exception("This method is not supported when working mode is server") if isinstance(node_list, list): node_list = np.array(node_list, dtype=np.int32) return [ - t.as_array() for t in self._graph.get_node_feature( + t.as_array() for t in self._graph_data.get_node_feature( Tensor(node_list), feature_types)] @@ -236,10 +291,12 @@ class GraphData: TypeError: If `edge_list` is not list or ndarray. TypeError: If `feature_types` is not list or ndarray. """ + if self._working_mode == 'server': + raise Exception("This method is not supported when working mode is server") if isinstance(edge_list, list): edge_list = np.array(edge_list, dtype=np.int32) return [ - t.as_array() for t in self._graph.get_edge_feature( + t.as_array() for t in self._graph_data.get_edge_feature( Tensor(edge_list), feature_types)] @@ -252,7 +309,9 @@ class GraphData: dict: Meta information of the graph. The key is node_type, edge_type, node_num, edge_num, node_feature_type and edge_feature_type. """ - return self._graph.graph_info() + if self._working_mode == 'server': + raise Exception("This method is not supported when working mode is server") + return self._graph_data.graph_info() @check_gnn_random_walk def random_walk( @@ -285,5 +344,7 @@ class GraphData: TypeError: If `target_nodes` is not list or ndarray. TypeError: If `meta_path` is not list or ndarray. 
""" - return self._graph.random_walk(target_nodes, meta_path, step_home_param, step_away_param, - default_node).as_array() + if self._working_mode == 'server': + raise Exception("This method is not supported when working mode is server") + return self._graph_data.random_walk(target_nodes, meta_path, step_home_param, step_away_param, + default_node).as_array() diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index eba6539b651..dbe4ff6212b 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -18,6 +18,7 @@ Built-in validators. """ import inspect as ins import os +import re from functools import wraps import numpy as np @@ -912,16 +913,36 @@ def check_split(method): return new_method +def check_hostname(hostname): + if len(hostname) > 255: + return False + if hostname[-1] == ".": + hostname = hostname[:-1] # strip exactly one dot from the right, if present + allowed = re.compile("(?!-)[A-Z\\d-]{1,63}(?1*irG;*E-xqan4NdG2g1d&&LiUA_m0)B`w$&!KD=93^cS}{**|B}IxV4WP%%W)pf#;U@0>{jmex>g0S|!&uiz5fQ{v)fd=LgZwqYm_) zO0!{0jtuIqCp++7>yZX-=ggK>Qzxp<$x(S;^G>_zqvs2qJ)IMY|BjuE6uNtZ=sSz% zrP+l`rOT5mlj^*p?<+;y6EynHn%CTGsSUBcNKVddR;#`fM4nVL{JPaSjcAV_SYFU= zOxq*2P94!si6oDv+IDd)pX`jAwmiA^U$!h374pUbuy69VHM^#qLxQ(EYkC1K$?|q-z zH$E2a>pP#rziPnYdGqkYJISeU@3=S$=?$_0t?iDhLHpX>zH2N#`-5ITwrr)|_noM} zn|{MLm8$+89=LwsR@4!Vy`7z$dE4=#4x=ty;oOLM&?!g!p9W}v255i=Xn+Q2fCgxQ z255i=Xn+Q2fCgxQ255i=Xn+Q2fCgxQ255i=Xn+Q2fCgxQ255i=Xn+Q2fCgxQ255i= zXn+Q2fCgxQ255i=Xn+Q2fCipO15@dR84vRmDoXJ>_2EDF~t7!=+g|xPtrTE z$8h|hjl>Yg564gcuOX}-$>_il!}@_X5<{#XtRJi&>Cuo>iexL_X@CZ3fCgxQ255i= zXn+Q2fCgxQ255i=Xn+Q2fCgxQ255i=Xn+Q2fCgxQ255i=Xn+Q2fCgxQ255i=Xn+Q2 zfCgxQ255i=Xn+Q2fCgxQ255i=Xn+Q2fCgxQ255i=Xn+Q2fCgxQ255i=Xn+Q2fCgxQ z255i=Xn+Q2fCgyb|7c(=KB+$+SV7>q<-Nf9yK%qjdVWxH?a;XLz;`OmhOIsf-4E)n z*Z%9D;9koyj$*HBuu~fBv<5q)!Om*1a~kZt2D_laE^4q#8tk$LdrgB~(O|D@u&WyE zng+YB!ER`8myziRyEk#W7zQlbGOr- zyNw=UWNSluV>AE6=}>~m1Mt$^lH}Z*OA8THA(gx 
zNUvqheiLcTUYBIQh4gyX?6;A|>_~G8nZCIMNc)ba*@s!Ifu22 zv~y85tX`zEMb2UUBJEt14J#OFTI3woFw)LNEr(T%(7+j3G>^eLMrhzXCYt?WB_lNE zp|^;#VJ#z_Jum60u$mDXb6K-tJtLhxC&`8tjnJ6Rnhk3j>FjAqHmqud#?`FZu&$BL zzADLvm5tD-WzB}QjdXTRk`1dHp;3kGTM`=9H$vkU&{YWyD;();Q9p+@j`VqpvSF1Y zoh`awgLRJ3!22~(Hmr1{vu%krSnCK4JF9b8?FbDAXwjNQ9q%1*&BBqi+(jC*Mb3Xk z8l8)>e?uCxMb3ASM(3jJ-NR^+^9IuBT-5R=(iOPgqVe5A8pn4&odj^OkTh;lTl_e zql3)kJ})+=|D(*&a2QPlqnTl}Bp59b2d+Ap{LYDE^9+^@>&**Tc=b0OIHLjpe-I-O)9hJS=!1R@jgXK{11%uXh>C!S z3W&3Y!MJKTGlpXlhT|B<${4e=j>C?I>&&Xd&JvExIhmf;@y*O%=TXP2*VfW7GKSK_TShkzjHPEYUIqYWSvt)yj04|X_;z!$VVHdG z<^*oSUwCvdGgiDN%AYfKaflHO`RDk-nyd;~6|gE`Rlur%RROC4Rt2mISQW4;@K;-b zdb=x<$$)iZ1B07TTWw`)RbOM}>gobHRB95nlUh1g-IYr>HsM=) zL&KDVA>fY0?iEN^I9PF9HK+*lvt z{SVHy75chU*SMl@=Hxn7qzA^vMn?v>k0JcTWZ^3&=c(7{4Udd&8|xbxy6#rjRJUH6 zJL~$~ZD`#VZRp#xXJ4N^wqazHGI!36xp!=xTq0hDYq9C(*1FY|t=;MR>hALEf4pmp znt6S7t}pi5{p@S^bEZt0ZgWK{Y z4>%{j0r`tbbje@HhvfD0O1VUKNWYU_lfEnMm-?ioQc}Ds{ziORd{i74JH>foNcdFv zC*iE{ps-bF5@rY@f0_RU{|tYS@8=iuQSMXjE$&(F5I4jv;}Y!W>^to9>``_DTg9e5 zS3U1~Uho|EZ1U84iam_`g8L=+3HMfagS*`Aa9wo0;yUdbb2YmvTps6#&U4Pkoa4@% zbEZ>peB}74z6FT#f6f3N+)KA&Rpu@>1lpm>E~m3Y^mn*HN6o z$gapsF(YQ8<)8`Y3{sqd$l}PZm>x4wCC+Jr0~Dt}{AFZuOpWPi8FqSLKgH<_p9_Ck zWX4o<8|cBy>Z3So!w16Wii{!?EyeZefomzwns8P4K#^W#pd~n`9j>7``LHWoRiqZ_ zXfYVXIeCiH8+twLikd|#T7+{3VK2q$2^|i-9yOvSS_lSkP7lTD4mE@hNA;+IZpAqR zu$$s^g``kJRE_Fr0a$}`x+qR(@U4&(F{3J)k8{?*PKwhJd?@%<#E6(^9%#ck9TcZM zm(IW<$i*wpwJH^Qb!@*ocjp%3&$l;tE#c2z?8w`idh>B+8oE&VUIIV&21l|oB zVH3^5`%e#Sr8q5teBe7_J!~Kq=QP04nV{r^$aYkP4Y$72N`Q z@v@pIPNV+`|HY6IGSLjYANImVinH236sN&o>K_WJAstoV{je7{ zP@H<-NB+{F8B)nKjG zZ-#GEPz~y+9Mt2ST8dNSz2ch@FoP;8!#VY^hT^RBp7mY{7y%QNf;ybDlHyc*$Gm3) zdcZ&>IHwL)Q=BUAJnvXQ4d|#C)Zm;dinBug+&j;222_;AIW=$v#aS-DEPw7d{3e2+ z59chWIF<4q`DMTEH&6!W^ubDsvrJwp@A0dC9i_38hs!9=Z89S-^_hMZDIkxRbsNQ5 zD!nE%KEr3C6y6W>a4E%EA{~@o^XWbVC2>vywML4Gm-b!&6h>hY=uj#~7DIHv>7 zr8skh4&j8X$~p?-<0}v6P@LI9ROpaQS;ZH?JU+f=Q=D1+`$AMQBoq1Zc`OfSQ5=f{8eM{>j6s6lV(iEc*vR7YxMWoF+Jh;*_%*`>db}I`ZK2Pac+2oHBMM 
ztMR6wA~(*dhh-F})bmGnCU5X2a^dq&9+py^68y^bkG#$s$cc06UsjF0!I`{@>^P?eW+@JI|Jk#EGdL63z*>;cz^^*&_1CKz_bcu{b2?`r z9J5x*d{x0OZSlmW-TU3IaA{80zNs5qg9e17gTwtebV4aen2ys`=b=-5rb3p@xN}M9ZTv1ofv(R%y z0|QEoA{03nYb~8ptaliV$=>@3@RN^(hVXDnW`6aEv*gciCMz4fH4h zir{xFaEv+6y6hV0R{RvfXTR!L;2dzC)j*fxqX=I6Ir~+o*EyhpPQ^D?TSPZM7v5d_;?SUK-VYM@OKC<1RgftB6% z^BQPXc#6Q;x@;%xQ|-Go(4ue@fwd*D@}7OF2AUO?B6z?>tQ@kvr-3HLLlNBINvzDb z9nwIf;-&~LumLNVZ1Xj+T5(YXC#b;6NpML64T_T@IG8I~sRAc8P_H;Bf}MGRxx#z` zsx+`FWyb=z`G)nz(QO;Y#uLmZ8mNog0xS!F$AcXipB}Cq{foz2m?tz)8wCLmi{Itq zxrLedEwNMM@cpCg5=Q9lmebU@lEZITA04yj&RCe4+mN{SSeSn*5o z6Y(SQU&Ob?e-K~6*Py4xzY&j%cZ<8kE#i>ajb8z*5EqEkMJPr@K?K5Q!hZ-CgM2HEp;Nbtnf6V`mf0zG9{x$wb{ImRb z`BVH6em}pR-^lmzIlhix%FpJ@`6Tb>-P~2q9;I?xc zxL&T2tK{ZzWn7FCIfng&y~Mu9o@ZZW&$3Ukr`aR;G^n#fY$v-4m&}^13Ro4eDqvN> zs(@92|Bn?AZoJO2F{hbF3ynvZ(}l(<=8;0a2JPoZ&;xx3J~i#b?m9ANG$G$xn>g~onng1`2<6n_!3kJ-=NV6b-|!8`X7+;bg!;EtUHx9=c0ww>U%F@mGp2D$Y$CY2kzm7Wg7pmq zSJe}&TSc(8j$lnK!Id=xt5*`NswTLiis14U1S^*lTvkc&wq*pD-bQfAQi6+@5L~pF z;KD@&Z(T@m!L0=6FCaK?KEb*32+o;HaP}O6vt|=iXAzvK61-(5!5Oy@oIZnK#dLzx zDhN)UMsUhhg5^^PmX#AMEhAV`O0c+uV78bb%o5B%g6RxFB~37;5KN{BCXxi>34*aW z!J-(!Xc56klwdeQFcc;j3=s?j3Hk#BeSU&oA3@nmP?8CX5H+^sofo z9)d16L8ptL!%5KYAZW9*_(OonY4T%)>k}*D>o$=8#NhwdWL3bbfK>sj0#*gA3Ro4e zDqvN>s(@7is{&R9tP1?KRlo^sfof|Bhp=$*O== z0jmO51*{5K6|gE`Rp5WS0@;b8fbCl&88N4ej}@OUSIc!Yj)!$wo1KW#tVfEQi;tC= z&1FWJiFV+uzU)MpW}VEwQ|vF(%M7#~514_oLNx0{_Eh$r zGPO)cV>qinI}xN=4`;iwr%KH-6>S6klUV_p^-wmJ?J6}&O*D$vejq#Hr&$le3)xty zUTPp650rx|;iFl{;p6Z^samR|t$2y+vJ+mK^#JULkC&LGD%t|pO|HF6vyMRs`%8=x z6XD;IgQ4t%M6>SCd@@aZjR?iv; z0qkWboHXm8@}lx-R?X^Y2yY|Ka?q^1lyT)nXl7Nk4rk$Qw9~8u%0guv8qh?8cwfa? 
zHkvh&x~eRMIyBG#&cgdDpjrDJGnR29-XY`DLuz|;Gx|8DUN&YP5$fy||<#A3e+(U78 zCl4immNqjg>IJnpXE(*!m0XoPls3{P!Xrb0Rd5%@xg#kgSEco|fx5veoO1`o8BhE& zDWui3j=FHpYB)}Db|#J|eyNyg6?KBuIA)M%5$>Qk+Y`Y=tD-9g zY6p!tXFJ6ii~l+iR8&PrIh>P&V-#mw{B-=+DN|8V8y@3j4cta?M&mv4(YZ{O8&9Ovl#ZIQ(+RgyoGr04vEL@ml!}`0SUMeW V3&qi5gRwJ7BWa>WJf2MFe*=)}@SXqw literal 24576 zcmeI4dvH|OeaG+F_kHhET1hMI>Z||(0t5&wg8)fL=(Xs%dN1g;M@S$pLM$L-*&r<2 zGM2Gy3oNJ9LsRUbDdP|i;}Q>(5RdCpmnIqyNr+1mjY}OSr5-h5TqV|c`V+`!{Wz2S z)0xhs_u4Ccj_!x=-u>zBJ?FsLTUQh6jgdh2zJa|RvA|l+&LPC*2Lc?&nc!i8#}}Ic zVE$)cZ0HO4Bi~Kj$}PW;)mhFcM>+X>>OaCaebXzTS3s|TUID!VdIj_f=oQc_pjSYz zz@KdeDvXxI^mNn}>*(wwb)^NhMQxP@+l%95+FDy&5H1dc3kpNUf%wle=e0n%xG5Z{ ztb#{ZgmqQFEeslNQ2fJfEy#tKBISYFq>7Rc_IE4GLf$G|_?FF^Xfr{eh%m=@` z`%jhq;BX#1?ftK7?|;o%x^#uXl30*}xLz$v1`q7-gIRrB$HCaX`|Y-cQ=x730vO@V zW}AP*nc;!d&*sshex{D8d(?7uscKVxt-Pnari>_EN}-ajaPqI@AIN_x56P|a202NZ zm3}V0C7qIDQk}F$a)>kHPs9u2s2COVMV~M$+z>7bV?vKmBn0?5{w9BkALsk{Qa*#{ zY`1KeZ4yK5HH_H<@$H zcGF$cHPboMF;m2pYjPWBjMt6ljU&c(W1i7#xM!F$Tri9pqK15f56z+*=pq_JJ*Wr; zxViZk$J(&wLcGCXv%y2PA)5jI3q`1s-$L<3djlJJq%}e((a`5${kfA<){HQ;oS_UHz_M=xALxXhm@fv)PQ># zPLKDKq%7sE@-8VwwUE=paCUhwdOuNyl(VD+a%%A|hSTjG@Lp7!lp(SW)uU$I&2We} z$2*|pC{3goHKTe=7)~_tQ*Vx9S8_-Zs)n2>!|6)AocO7HSFw|=kW-Dj7*1#6k;KdL zHTf?d`QWtRc81fIa5dp`IU*k; zTTlx;R~y6GnQ%Pes+=oF$Y!Wd3*O0ab|e%f9GBg4F4+V*^>_!vi6mGOiliCYP4ZAD zePBx$p$Z26X&7P3wgft?ZCl5nT2X1CK zO&-M)lG>#al8bggP7}jvbWeE{DNkxA4?)fj+{kbm+%LMPB(Ic5)}uPeX<#_@?y&ns z@t))*>ma8N*E5{3JHZ_mr^I_C2ZbRg%y8;lH{A(v&6y%=At#LM7*4J0TdtdMRk}dd z!1>dHYZ*?BE8_YVT(?HaYRIX=H4LZPmF9}Tl`TrL!D+?S45!LD?Mf4UVm?`gTH(2> z7*3`0+s`j2?F=X6T<2#N!aFre-%b-53xSZjXIr<#mf$Mk=SqeE-xQyYHI#xRR zgn&>)GEoKOlro$W`;22HKPLpp63D5*B@Abq{Z0D}^cHhuF)D+cZ49T_9<#p*{mD(T z2y)7BF~cdcueZmb$GJo@PzmG|F`TXHtbIN7N#i64IVE^2!zonXR%fBt>LVDnK~5pV zDNu*hx1nDvCFzjUh6@-@zPec*;yFHp1i*>le1@|{<gefB%hzg{M(kkf)UF`PVQn=%4D@C5O}c^JWY3}>TaQ?@}L-A_^= zrx|Z#I2+{m6&v*8A(9MzN(66UI1kHX^7}TEEt4ccP7{8Z;pEDd@|bnnW+GnbD&X~YjPob}QNvK>YfS4aZ%rxCoK;jEJ;qz_=cF-bg-(}348oE)h^ znt%~WjJV%l_mouDY 
z{BMLz^Bs#s1i1c0@G^$8l)u3L2F8JR2oE_`cqzll%w?uD)Z%wB1Sk9+ zVldUYprK~JlOZ^abEaIF35{x~$?sqYcH>p!988p=8fx_08G;H`ylU_n@-@`pR~dq0 zEH#dxS%XhQ^?rpR$VSOn3bU|T4Tb$OLy!zpMhPZqH#Ah|ml%R*m@rJi9PXlqYW*TZ z5DZO*37F!IX{g38Fa#c6K@-dXdo)z-=NW{d!xj7WjP^I6> z5G?Tg=eV2PoQAghEeyeovQQ6qiMy$xkl)M@Oxz60;>Nj48mdS$EeL3SKDzMcHNu~s z>Ep&VRPHmve%NP2D8An?$XzP!(@y80vaE%k5Jv+61JgnCqsslDn>wN5QnH>qpXC8|$#s21fj zWk&gT{Ps&H+eR5Q8hTeZG%&Jz%SWc8B8A+c?{~=9FKb78>ekfg%{ziIL zdPy3U9+L*7T~b7(rxhLX{VeWXmaf}-l?~hC2A7PGi$Akq&j~}J@*y9w39;5i^5XB>pQhekH z#lc4?9v-B4=rF~DhbYDlQali&IBO0txL2dtx0m9cK8m~dQ0(1J zv8R{ft{#fryC{-wicvzbD@w7mi(*G7#r6)0ZS53ywo%-%lVW5C#nuSLmR5?*EfkxY zDK<7yY-prd-#{^3Pq8jcv9^w4O)bUh8j4lb6f3JJZm*;m+D@?|M6tYrVp%!G(lUxA zr4+Z7P%Pd?v8b5h)*_09TPYS4Qp_))xFw(B<}DOAZKjyFiQ>jQiW@dke0T%J+=nSX zluL2_LloDor(1@f58#iWVzHvxTC` zOwnkfXfX2dmjLtI{NEuw*mxQGZKVF5ga7qSuYg_wy#jg#^a|(|&?}%+ApZ_5a^|U+C-q_4WT>TE?ue|G)nauDrhf|3Mujef|G@A^88p_5XPp z8Y-Ee|1al%#HqK`%j$&MuZGl2)uc=-S71edOsQ6~6q|fozAB%Q56d<3a@j2XO8P#m z$KNZJNa^Bd;!nlziqD8$;wCW>R@VQM@S1Q`XcpEAGXHCc)i=EYdIj_f=oQc_pjSYz zfL?+B?Fs~klU#;x+{K8S%s82G-P7(FA%k#Bmz}|3FUxv4qdMcHC(qMP4#Nk`cLawM zS=Q+cXGXQh>&YXBAge7noWQb91#e_H-S<3Rau9Bo0a+fFH4!`=yy2d5-y<=|Y7Y*( zS=LLzrr>G!1@{y=fZFG?TrBIwU~;g@J?g$d2H@yBg2PUh^#Z;XOm;`zqhvqaBnK+t zU|Hk%ReZ~x?~amwc*d^au$^T+kK6I9Zl61!>_c7iN3XK1lNiG-hGyM90-qyC(crMc zvYt!7jj`*7YnJRq(fQ+&S=Lzk>*=>$7hN|j5C?GR)MEbE!T-SlOyB3BRTg{;otFwe3^18)TGx&p2u(gUv%vTQ8t z=|F$r4dAqc$Su$g6z z_}}v1bB;SNktn=I$TG34C;Wr{x14>>anc1@@EVOQ>$rb|f6!U#>?56Uu0oc9Weumz z`8PN-oTa1#vfx}rEbCa>57OqGoHK*8!}-;Xhq*6*7<}R0V@K0KbKG)rqz$6l@ll5J zcv@lFF~?=cEdp;?v3R9G3|gaIdDv7|u}Ydub*|zhi<#Ag2@$F`P$J zM^oQ(gdF{(6_rBHqYURrYI*9YBhwKga91c)j*l>$M^dHKa)-&0Nt#hPrA57Mw~-pSZ%qjAXE^=IzfAF|SJm638k|Pl&v5o7pH2RydPcoU bs^C644R{~J(ULop&#Hs!8Bz)Nl4<-OI=H}F diff --git a/tests/ut/data/mindrecord/testGraphData/testdata b/tests/ut/data/mindrecord/testGraphData/testdata index 52359734692a4652a40c5a49d8b43c32f58e856c..ad97dbae98de88cd441cf63134099a059eb1c6aa 100644 GIT binary patch delta 41 
xcmX>#oB7miX2#x)j9yH_$@#ejMXANbnfZC~#hF#9N>&D&J((8SZgM!)3jjy45AFZ} delta 25 hcmX>#oB7miX2yz*j9yHe6POm*PTs&Oy2;^GF93gE3E%(# diff --git a/tests/ut/python/dataset/test_graphdata_distributed.py b/tests/ut/python/dataset/test_graphdata_distributed.py new file mode 100644 index 00000000000..1e457e069b3 --- /dev/null +++ b/tests/ut/python/dataset/test_graphdata_distributed.py @@ -0,0 +1,125 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import random +import time +from multiprocessing import Process +import numpy as np +import mindspore.dataset as ds +from mindspore import log as logger + +DATASET_FILE = "../data/mindrecord/testGraphData/testdata" + + +def graphdata_startserver(): + """ + start graphdata server + """ + logger.info('test start server.\n') + ds.GraphData(DATASET_FILE, 1, 'server') + + +class RandomBatchedSampler(ds.Sampler): + # RandomBatchedSampler generate random sequence without replacement in a batched manner + def __init__(self, index_range, num_edges_per_sample): + super().__init__() + self.index_range = index_range + self.num_edges_per_sample = num_edges_per_sample + + def __iter__(self): + indices = [i+1 for i in range(self.index_range)] + # Reset random seed here if necessary + # random.seed(0) + random.shuffle(indices) + for i in range(0, self.index_range, self.num_edges_per_sample): + # Drop reminder + if i + 
self.num_edges_per_sample <= self.index_range: + yield indices[i: i + self.num_edges_per_sample] + + +class GNNGraphDataset(): + def __init__(self, g, batch_num): + self.g = g + self.batch_num = batch_num + + def __len__(self): + # Total sample size of GNN dataset + # In this case, the size should be total_num_edges/num_edges_per_sample + return self.g.graph_info()['edge_num'][0] // self.batch_num + + def __getitem__(self, index): + # index will be a list of indices yielded from RandomBatchedSampler + # Fetch edges/nodes/samples/features based on indices + nodes = self.g.get_nodes_from_edges(index.astype(np.int32)) + nodes = nodes[:, 0] + neg_nodes = self.g.get_neg_sampled_neighbors( + node_list=nodes, neg_neighbor_num=3, neg_neighbor_type=1) + nodes_neighbors = self.g.get_sampled_neighbors(node_list=nodes, neighbor_nums=[ + 2, 2], neighbor_types=[2, 1]) + neg_nodes_neighbors = self.g.get_sampled_neighbors( + node_list=neg_nodes[:, 1:].reshape(-1), neighbor_nums=[2, 2], neighbor_types=[2, 2]) + nodes_neighbors_features = self.g.get_node_feature( + node_list=nodes_neighbors, feature_types=[2, 3]) + neg_neighbors_features = self.g.get_node_feature( + node_list=neg_nodes_neighbors, feature_types=[2, 3]) + return nodes_neighbors, neg_nodes_neighbors, nodes_neighbors_features[0], neg_neighbors_features[1] + + +def test_graphdata_distributed(): + """ + Test distributed + """ + logger.info('test distributed.\n') + + p1 = Process(target=graphdata_startserver) + p1.start() + time.sleep(2) + + g = ds.GraphData(DATASET_FILE, 1, 'client') + nodes = g.get_all_nodes(1) + assert nodes.tolist() == [101, 102, 103, 104, 105, 106, 107, 108, 109, 110] + row_tensor = g.get_node_feature(nodes.tolist(), [1, 2, 3]) + assert row_tensor[0].tolist() == [[0, 1, 0, 0, 0], [1, 0, 0, 0, 1], [0, 0, 1, 1, 0], [0, 0, 0, 0, 0], + [1, 1, 0, 1, 0], [0, 0, 0, 0, 1], [0, 1, 0, 0, 0], [0, 0, 0, 1, 1], + [0, 1, 1, 0, 0], [0, 1, 0, 1, 0]] + assert row_tensor[2].tolist() == [1, 2, 3, 1, 4, 3, 5, 3, 5, 4] + 
+ edges = g.get_all_edges(0) + assert edges.tolist() == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, + 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40] + features = g.get_edge_feature(edges, [1, 2]) + assert features[0].tolist() == [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, + 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0] + + batch_num = 2 + edge_num = g.graph_info()['edge_num'][0] + out_column_names = ["neighbors", "neg_neighbors", "neighbors_features", "neg_neighbors_features"] + dataset = ds.GeneratorDataset(source=GNNGraphDataset(g, batch_num), column_names=out_column_names, + sampler=RandomBatchedSampler(edge_num, batch_num), num_parallel_workers=4, + python_multiprocessing=False) + dataset = dataset.repeat(2) + itr = dataset.create_dict_iterator() + i = 0 + for data in itr: + assert data['neighbors'].shape == (2, 7) + assert data['neg_neighbors'].shape == (6, 7) + assert data['neighbors_features'].shape == (2, 7) + assert data['neg_neighbors_features'].shape == (6, 7) + i += 1 + assert i == 40 + + +if __name__ == '__main__': + test_graphdata_distributed()