forked from OSSInnovation/mindspore
Gnn data processing supports distributed scenarios
This commit is contained in:
parent
1ca715c7e7
commit
8ee4d8e92d
|
@ -15,7 +15,14 @@ include(${CMAKE_SOURCE_DIR}/cmake/external_libs/json.cmake)
|
||||||
include(${CMAKE_SOURCE_DIR}/cmake/dependency_securec.cmake)
|
include(${CMAKE_SOURCE_DIR}/cmake/dependency_securec.cmake)
|
||||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/protobuf.cmake)
|
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/protobuf.cmake)
|
||||||
|
|
||||||
|
SET(MS_BUILD_GRPC 0)
|
||||||
if (ENABLE_DEBUGGER OR ENABLE_SERVING OR ENABLE_TESTCASES)
|
if (ENABLE_DEBUGGER OR ENABLE_SERVING OR ENABLE_TESTCASES)
|
||||||
|
SET(MS_BUILD_GRPC 1)
|
||||||
|
endif()
|
||||||
|
if (ENABLE_MINDDATA AND NOT CMAKE_SYSTEM_NAME MATCHES "Windows")
|
||||||
|
SET(MS_BUILD_GRPC 1)
|
||||||
|
endif()
|
||||||
|
if ("${MS_BUILD_GRPC}")
|
||||||
# build dependencies of gRPC
|
# build dependencies of gRPC
|
||||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/absl.cmake)
|
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/absl.cmake)
|
||||||
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/c-ares.cmake)
|
include(${CMAKE_SOURCE_DIR}/cmake/external_libs/c-ares.cmake)
|
||||||
|
|
|
@ -83,6 +83,7 @@ endif()
|
||||||
if (ENABLE_TDTQUE)
|
if (ENABLE_TDTQUE)
|
||||||
add_dependencies(engine-tdt core)
|
add_dependencies(engine-tdt core)
|
||||||
endif ()
|
endif ()
|
||||||
|
|
||||||
################### Create _c_dataengine Library ######################
|
################### Create _c_dataengine Library ######################
|
||||||
set(submodules
|
set(submodules
|
||||||
$<TARGET_OBJECTS:core>
|
$<TARGET_OBJECTS:core>
|
||||||
|
@ -182,3 +183,7 @@ else()
|
||||||
set_target_properties(_c_dataengine PROPERTIES MACOSX_RPATH ON)
|
set_target_properties(_c_dataengine PROPERTIES MACOSX_RPATH ON)
|
||||||
endif ()
|
endif ()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if (NOT CMAKE_SYSTEM_NAME MATCHES "Windows")
|
||||||
|
target_link_libraries(_c_dataengine PRIVATE mindspore::grpc++)
|
||||||
|
endif()
|
|
@ -18,83 +18,103 @@
|
||||||
#include "pybind11/stl_bind.h"
|
#include "pybind11/stl_bind.h"
|
||||||
|
|
||||||
#include "minddata/dataset/api/python/pybind_register.h"
|
#include "minddata/dataset/api/python/pybind_register.h"
|
||||||
|
#include "minddata/dataset/engine/gnn/graph_data_client.h"
|
||||||
#include "minddata/dataset/engine/gnn/graph.h"
|
#include "minddata/dataset/engine/gnn/graph_data_impl.h"
|
||||||
|
#include "minddata/dataset/engine/gnn/graph_data_server.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
||||||
PYBIND_REGISTER(
|
PYBIND_REGISTER(
|
||||||
Graph, 0, ([](const py::module *m) {
|
Graph, 0, ([](const py::module *m) {
|
||||||
(void)py::class_<gnn::Graph, std::shared_ptr<gnn::Graph>>(*m, "Graph")
|
(void)py::class_<gnn::GraphData, std::shared_ptr<gnn::GraphData>>(*m, "GraphDataClient")
|
||||||
.def(py::init([](std::string dataset_file, int32_t num_workers) {
|
.def(py::init([](const std::string &dataset_file, int32_t num_workers, const std::string &working_mode,
|
||||||
std::shared_ptr<gnn::Graph> g_out = std::make_shared<gnn::Graph>(dataset_file, num_workers);
|
const std::string &hostname, int32_t port) {
|
||||||
THROW_IF_ERROR(g_out->Init());
|
std::shared_ptr<gnn::GraphData> out;
|
||||||
return g_out;
|
if (working_mode == "local") {
|
||||||
|
out = std::make_shared<gnn::GraphDataImpl>(dataset_file, num_workers);
|
||||||
|
} else if (working_mode == "client") {
|
||||||
|
out = std::make_shared<gnn::GraphDataClient>(dataset_file, hostname, port);
|
||||||
|
}
|
||||||
|
THROW_IF_ERROR(out->Init());
|
||||||
|
return out;
|
||||||
}))
|
}))
|
||||||
.def("get_all_nodes",
|
.def("get_all_nodes",
|
||||||
[](gnn::Graph &g, gnn::NodeType node_type) {
|
[](gnn::GraphData &g, gnn::NodeType node_type) {
|
||||||
std::shared_ptr<Tensor> out;
|
std::shared_ptr<Tensor> out;
|
||||||
THROW_IF_ERROR(g.GetAllNodes(node_type, &out));
|
THROW_IF_ERROR(g.GetAllNodes(node_type, &out));
|
||||||
return out;
|
return out;
|
||||||
})
|
})
|
||||||
.def("get_all_edges",
|
.def("get_all_edges",
|
||||||
[](gnn::Graph &g, gnn::EdgeType edge_type) {
|
[](gnn::GraphData &g, gnn::EdgeType edge_type) {
|
||||||
std::shared_ptr<Tensor> out;
|
std::shared_ptr<Tensor> out;
|
||||||
THROW_IF_ERROR(g.GetAllEdges(edge_type, &out));
|
THROW_IF_ERROR(g.GetAllEdges(edge_type, &out));
|
||||||
return out;
|
return out;
|
||||||
})
|
})
|
||||||
.def("get_nodes_from_edges",
|
.def("get_nodes_from_edges",
|
||||||
[](gnn::Graph &g, std::vector<gnn::NodeIdType> edge_list) {
|
[](gnn::GraphData &g, std::vector<gnn::NodeIdType> edge_list) {
|
||||||
std::shared_ptr<Tensor> out;
|
std::shared_ptr<Tensor> out;
|
||||||
THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out));
|
THROW_IF_ERROR(g.GetNodesFromEdges(edge_list, &out));
|
||||||
return out;
|
return out;
|
||||||
})
|
})
|
||||||
.def("get_all_neighbors",
|
.def("get_all_neighbors",
|
||||||
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeType neighbor_type) {
|
[](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeType neighbor_type) {
|
||||||
std::shared_ptr<Tensor> out;
|
std::shared_ptr<Tensor> out;
|
||||||
THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out));
|
THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out));
|
||||||
return out;
|
return out;
|
||||||
})
|
})
|
||||||
.def("get_sampled_neighbors",
|
.def("get_sampled_neighbors",
|
||||||
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeIdType> neighbor_nums,
|
[](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeIdType> neighbor_nums,
|
||||||
std::vector<gnn::NodeType> neighbor_types) {
|
std::vector<gnn::NodeType> neighbor_types) {
|
||||||
std::shared_ptr<Tensor> out;
|
std::shared_ptr<Tensor> out;
|
||||||
THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &out));
|
THROW_IF_ERROR(g.GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &out));
|
||||||
return out;
|
return out;
|
||||||
})
|
})
|
||||||
.def("get_neg_sampled_neighbors",
|
.def("get_neg_sampled_neighbors",
|
||||||
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeIdType neighbor_num,
|
[](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeIdType neighbor_num,
|
||||||
gnn::NodeType neg_neighbor_type) {
|
gnn::NodeType neg_neighbor_type) {
|
||||||
std::shared_ptr<Tensor> out;
|
std::shared_ptr<Tensor> out;
|
||||||
THROW_IF_ERROR(g.GetNegSampledNeighbors(node_list, neighbor_num, neg_neighbor_type, &out));
|
THROW_IF_ERROR(g.GetNegSampledNeighbors(node_list, neighbor_num, neg_neighbor_type, &out));
|
||||||
return out;
|
return out;
|
||||||
})
|
})
|
||||||
.def("get_node_feature",
|
.def("get_node_feature",
|
||||||
[](gnn::Graph &g, std::shared_ptr<Tensor> node_list, std::vector<gnn::FeatureType> feature_types) {
|
[](gnn::GraphData &g, std::shared_ptr<Tensor> node_list, std::vector<gnn::FeatureType> feature_types) {
|
||||||
TensorRow out;
|
TensorRow out;
|
||||||
THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out));
|
THROW_IF_ERROR(g.GetNodeFeature(node_list, feature_types, &out));
|
||||||
return out.getRow();
|
return out.getRow();
|
||||||
})
|
})
|
||||||
.def("get_edge_feature",
|
.def("get_edge_feature",
|
||||||
[](gnn::Graph &g, std::shared_ptr<Tensor> edge_list, std::vector<gnn::FeatureType> feature_types) {
|
[](gnn::GraphData &g, std::shared_ptr<Tensor> edge_list, std::vector<gnn::FeatureType> feature_types) {
|
||||||
TensorRow out;
|
TensorRow out;
|
||||||
THROW_IF_ERROR(g.GetEdgeFeature(edge_list, feature_types, &out));
|
THROW_IF_ERROR(g.GetEdgeFeature(edge_list, feature_types, &out));
|
||||||
return out.getRow();
|
return out.getRow();
|
||||||
})
|
})
|
||||||
.def("graph_info",
|
.def("graph_info",
|
||||||
[](gnn::Graph &g) {
|
[](gnn::GraphData &g) {
|
||||||
py::dict out;
|
py::dict out;
|
||||||
THROW_IF_ERROR(g.GraphInfo(&out));
|
THROW_IF_ERROR(g.GraphInfo(&out));
|
||||||
return out;
|
return out;
|
||||||
})
|
})
|
||||||
.def("random_walk",
|
.def("random_walk",
|
||||||
[](gnn::Graph &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeType> meta_path,
|
[](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, std::vector<gnn::NodeType> meta_path,
|
||||||
float step_home_param, float step_away_param, gnn::NodeIdType default_node) {
|
float step_home_param, float step_away_param, gnn::NodeIdType default_node) {
|
||||||
std::shared_ptr<Tensor> out;
|
std::shared_ptr<Tensor> out;
|
||||||
THROW_IF_ERROR(g.RandomWalk(node_list, meta_path, step_home_param, step_away_param, default_node, &out));
|
THROW_IF_ERROR(g.RandomWalk(node_list, meta_path, step_home_param, step_away_param, default_node, &out));
|
||||||
return out;
|
return out;
|
||||||
});
|
})
|
||||||
|
.def("stop", [](gnn::GraphData &g) { THROW_IF_ERROR(g.Stop()); });
|
||||||
|
|
||||||
|
(void)py::class_<gnn::GraphDataServer, std::shared_ptr<gnn::GraphDataServer>>(*m, "GraphDataServer")
|
||||||
|
.def(py::init([](const std::string &dataset_file, int32_t num_workers, const std::string &hostname, int32_t port,
|
||||||
|
int32_t client_num, bool auto_shutdown) {
|
||||||
|
std::shared_ptr<gnn::GraphDataServer> out;
|
||||||
|
out =
|
||||||
|
std::make_shared<gnn::GraphDataServer>(dataset_file, num_workers, hostname, port, client_num, auto_shutdown);
|
||||||
|
THROW_IF_ERROR(out->Init());
|
||||||
|
return out;
|
||||||
|
}))
|
||||||
|
.def("stop", [](gnn::GraphDataServer &g) { THROW_IF_ERROR(g.Stop()); })
|
||||||
|
.def("is_stoped", [](gnn::GraphDataServer &g) { return g.IsStoped(); });
|
||||||
}));
|
}));
|
||||||
|
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
|
|
|
@ -1,9 +1,29 @@
|
||||||
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
|
||||||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||||
add_library(engine-gnn OBJECT
|
set(DATASET_ENGINE_GNN_SRC_FILES
|
||||||
graph.cc
|
graph_data_impl.cc
|
||||||
|
graph_data_client.cc
|
||||||
|
graph_data_server.cc
|
||||||
graph_loader.cc
|
graph_loader.cc
|
||||||
|
graph_feature_parser.cc
|
||||||
local_node.cc
|
local_node.cc
|
||||||
local_edge.cc
|
local_edge.cc
|
||||||
feature.cc
|
feature.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (CMAKE_SYSTEM_NAME MATCHES "Windows")
|
||||||
|
add_library(engine-gnn OBJECT ${DATASET_ENGINE_GNN_SRC_FILES})
|
||||||
|
else()
|
||||||
|
set(DATASET_ENGINE_GNN_SRC_FILES
|
||||||
|
${DATASET_ENGINE_GNN_SRC_FILES}
|
||||||
|
tensor_proto.cc
|
||||||
|
grpc_async_server.cc
|
||||||
|
graph_data_service_impl.cc
|
||||||
|
graph_shared_memory.cc)
|
||||||
|
|
||||||
|
ms_protobuf_generate(TENSOR_PROTO_SRCS TENSOR_PROTO_HDRS "gnn_tensor.proto")
|
||||||
|
ms_grpc_generate(GNN_PROTO_SRCS GNN_PROTO_HDRS "gnn_graph_data.proto")
|
||||||
|
|
||||||
|
add_library(engine-gnn OBJECT ${DATASET_ENGINE_GNN_SRC_FILES} ${TENSOR_PROTO_SRCS} ${GNN_PROTO_SRCS})
|
||||||
|
add_dependencies(engine-gnn mindspore::protobuf)
|
||||||
|
endif()
|
||||||
|
|
|
@ -19,7 +19,8 @@ namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
namespace gnn {
|
namespace gnn {
|
||||||
|
|
||||||
Feature::Feature(FeatureType type_name, std::shared_ptr<Tensor> value) : type_name_(type_name), value_(value) {}
|
Feature::Feature(FeatureType type_name, std::shared_ptr<Tensor> value, bool is_shared_memory)
|
||||||
|
: type_name_(type_name), value_(value), is_shared_memory_(is_shared_memory) {}
|
||||||
|
|
||||||
} // namespace gnn
|
} // namespace gnn
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
|
|
|
@ -31,7 +31,7 @@ class Feature {
|
||||||
// Constructor
|
// Constructor
|
||||||
// @param FeatureType type_name - feature type
|
// @param FeatureType type_name - feature type
|
||||||
// @param std::shared_ptr<Tensor> value - feature value
|
// @param std::shared_ptr<Tensor> value - feature value
|
||||||
Feature(FeatureType type_name, std::shared_ptr<Tensor> value);
|
Feature(FeatureType type_name, std::shared_ptr<Tensor> value, bool is_shared_memory = false);
|
||||||
|
|
||||||
~Feature() = default;
|
~Feature() = default;
|
||||||
|
|
||||||
|
@ -45,6 +45,7 @@ class Feature {
|
||||||
private:
|
private:
|
||||||
FeatureType type_name_;
|
FeatureType type_name_;
|
||||||
std::shared_ptr<Tensor> value_;
|
std::shared_ptr<Tensor> value_;
|
||||||
|
bool is_shared_memory_;
|
||||||
};
|
};
|
||||||
} // namespace gnn
|
} // namespace gnn
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
|
|
|
@ -0,0 +1,103 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package mindspore.dataset;
|
||||||
|
|
||||||
|
import "gnn_tensor.proto";
|
||||||
|
|
||||||
|
message GnnClientRegisterRequestPb {
|
||||||
|
int32 pid = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message GnnFeatureInfoPb {
|
||||||
|
int32 type = 1;
|
||||||
|
TensorPb feature = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message GnnClientRegisterResponsePb {
|
||||||
|
string error_msg = 1;
|
||||||
|
string data_schema = 2;
|
||||||
|
int64 shared_memory_key = 3;
|
||||||
|
int64 shared_memory_size = 4;
|
||||||
|
repeated GnnFeatureInfoPb default_node_feature = 5;
|
||||||
|
repeated GnnFeatureInfoPb default_edge_feature = 6;
|
||||||
|
}
|
||||||
|
|
||||||
|
message GnnClientUnRegisterRequestPb {
|
||||||
|
int32 pid = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message GnnClientUnRegisterResponsePb {
|
||||||
|
string error_msg = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
enum GnnOpName {
|
||||||
|
GET_ALL_NODES = 0;
|
||||||
|
GET_ALL_EDGES = 1;
|
||||||
|
GET_NODES_FROM_EDGES = 2;
|
||||||
|
GET_ALL_NEIGHBORS = 3;
|
||||||
|
GET_SAMPLED_NEIGHBORS = 4;
|
||||||
|
GET_NEG_SAMPLED_NEIGHBORS = 5;
|
||||||
|
RANDOM_WALK = 6;
|
||||||
|
GET_NODE_FEATURE = 7;
|
||||||
|
GET_EDGE_FEATURE = 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
message GnnRandomWalkPb {
|
||||||
|
float p = 1;
|
||||||
|
float q = 2;
|
||||||
|
int32 default_id = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message GnnGraphDataRequestPb {
|
||||||
|
GnnOpName op_name = 1;
|
||||||
|
repeated int32 id = 2; // node id or edge id
|
||||||
|
repeated int32 type = 3; //node type or edge type or neighbor type or feature type
|
||||||
|
repeated int32 number = 4; // samples number
|
||||||
|
TensorPb id_tensor = 5; // input ids ,node id or edge id
|
||||||
|
GnnRandomWalkPb random_walk = 6;
|
||||||
|
}
|
||||||
|
|
||||||
|
message GnnGraphDataResponsePb {
|
||||||
|
string error_msg = 1;
|
||||||
|
repeated TensorPb result_data = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message GnnMetaInfoRequestPb {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
message GnnNodeEdgeInfoPb {
|
||||||
|
int32 type = 1;
|
||||||
|
int32 num = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message GnnMetaInfoResponsePb {
|
||||||
|
string error_msg = 1;
|
||||||
|
repeated GnnNodeEdgeInfoPb node_info = 2;
|
||||||
|
repeated GnnNodeEdgeInfoPb edge_info = 3;
|
||||||
|
repeated int32 node_feature_type = 4;
|
||||||
|
repeated int32 edge_feature_type = 5;
|
||||||
|
}
|
||||||
|
|
||||||
|
service GnnGraphData {
|
||||||
|
rpc ClientRegister(GnnClientRegisterRequestPb) returns (GnnClientRegisterResponsePb);
|
||||||
|
rpc ClientUnRegister(GnnClientUnRegisterRequestPb) returns (GnnClientUnRegisterResponsePb);
|
||||||
|
rpc GetGraphData(GnnGraphDataRequestPb) returns (GnnGraphDataResponsePb);
|
||||||
|
rpc GetMetaInfo(GnnMetaInfoRequestPb) returns (GnnMetaInfoResponsePb);
|
||||||
|
}
|
|
@ -0,0 +1,42 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package mindspore.dataset;
|
||||||
|
|
||||||
|
enum DataTypePb {
|
||||||
|
DE_PB_UNKNOWN = 0;
|
||||||
|
DE_PB_BOOL = 1;
|
||||||
|
DE_PB_INT8 = 2;
|
||||||
|
DE_PB_UINT8 = 3;
|
||||||
|
DE_PB_INT16 = 4;
|
||||||
|
DE_PB_UINT16 = 5;
|
||||||
|
DE_PB_INT32 = 6;
|
||||||
|
DE_PB_UINT32 = 7;
|
||||||
|
DE_PB_INT64 = 8;
|
||||||
|
DE_PB_UINT64 = 9;
|
||||||
|
DE_PB_FLOAT16 = 10;
|
||||||
|
DE_PB_FLOAT32 = 11;
|
||||||
|
DE_PB_FLOAT64 = 12;
|
||||||
|
DE_PB_STRING = 13;
|
||||||
|
}
|
||||||
|
|
||||||
|
message TensorPb {
|
||||||
|
repeated int64 dims = 1; // tensor shape info
|
||||||
|
DataTypePb tensor_type = 2; // tensor content data type
|
||||||
|
bytes data = 3; // tensor data
|
||||||
|
}
|
|
@ -0,0 +1,134 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_H_
|
||||||
|
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_H_
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "minddata/dataset/core/tensor.h"
|
||||||
|
#include "minddata/dataset/core/tensor_row.h"
|
||||||
|
#include "minddata/dataset/engine/gnn/feature.h"
|
||||||
|
#include "minddata/dataset/engine/gnn/node.h"
|
||||||
|
#include "minddata/dataset/engine/gnn/edge.h"
|
||||||
|
#include "minddata/dataset/util/status.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
namespace gnn {
|
||||||
|
|
||||||
|
struct MetaInfo {
|
||||||
|
std::vector<NodeType> node_type;
|
||||||
|
std::vector<EdgeType> edge_type;
|
||||||
|
std::map<NodeType, NodeIdType> node_num;
|
||||||
|
std::map<EdgeType, EdgeIdType> edge_num;
|
||||||
|
std::vector<FeatureType> node_feature_type;
|
||||||
|
std::vector<FeatureType> edge_feature_type;
|
||||||
|
};
|
||||||
|
|
||||||
|
class GraphData {
|
||||||
|
public:
|
||||||
|
// Get all nodes from the graph.
|
||||||
|
// @param NodeType node_type - type of node
|
||||||
|
// @param std::shared_ptr<Tensor> *out - Returned nodes id
|
||||||
|
// @return Status - The error code return
|
||||||
|
virtual Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) = 0;
|
||||||
|
|
||||||
|
// Get all edges from the graph.
|
||||||
|
// @param NodeType edge_type - type of edge
|
||||||
|
// @param std::shared_ptr<Tensor> *out - Returned edge ids
|
||||||
|
// @return Status - The error code return
|
||||||
|
virtual Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) = 0;
|
||||||
|
|
||||||
|
// Get the node id from the edge.
|
||||||
|
// @param std::vector<EdgeIdType> edge_list - List of edges
|
||||||
|
// @param std::shared_ptr<Tensor> *out - Returned node ids
|
||||||
|
// @return Status - The error code return
|
||||||
|
virtual Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) = 0;
|
||||||
|
|
||||||
|
// All neighbors of the acquisition node.
|
||||||
|
// @param std::vector<NodeType> node_list - List of nodes
|
||||||
|
// @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported
|
||||||
|
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is
|
||||||
|
// different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors
|
||||||
|
// is not enough, fill in tensor as -1.
|
||||||
|
// @return Status - The error code return
|
||||||
|
virtual Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
|
||||||
|
std::shared_ptr<Tensor> *out) = 0;
|
||||||
|
|
||||||
|
// Get sampled neighbors.
|
||||||
|
// @param std::vector<NodeType> node_list - List of nodes
|
||||||
|
// @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
|
||||||
|
// @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
|
||||||
|
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
|
||||||
|
// @return Status - The error code return
|
||||||
|
virtual Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
|
||||||
|
const std::vector<NodeIdType> &neighbor_nums,
|
||||||
|
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) = 0;
|
||||||
|
|
||||||
|
// Get negative sampled neighbors.
|
||||||
|
// @param std::vector<NodeType> node_list - List of nodes
|
||||||
|
// @param NodeIdType samples_num - Number of neighbors sampled
|
||||||
|
// @param NodeType neg_neighbor_type - The type of negative neighbor.
|
||||||
|
// @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id.
|
||||||
|
// @return Status - The error code return
|
||||||
|
virtual Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
|
||||||
|
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) = 0;
|
||||||
|
|
||||||
|
// Node2vec random walk.
|
||||||
|
// @param std::vector<NodeIdType> node_list - List of nodes
|
||||||
|
// @param std::vector<NodeType> meta_path - node type of each step
|
||||||
|
// @param float step_home_param - return hyper parameter in node2vec algorithm
|
||||||
|
// @param float step_away_param - inout hyper parameter in node2vec algorithm
|
||||||
|
// @param NodeIdType default_node - default node id
|
||||||
|
// @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path
|
||||||
|
// @return Status - The error code return
|
||||||
|
virtual Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
|
||||||
|
float step_home_param, float step_away_param, NodeIdType default_node,
|
||||||
|
std::shared_ptr<Tensor> *out) = 0;
|
||||||
|
|
||||||
|
// Get the feature of a node
|
||||||
|
// @param std::shared_ptr<Tensor> nodes - List of nodes
|
||||||
|
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
|
||||||
|
// does not exist.
|
||||||
|
// @param TensorRow *out - Returned features
|
||||||
|
// @return Status - The error code return
|
||||||
|
virtual Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
|
||||||
|
TensorRow *out) = 0;
|
||||||
|
|
||||||
|
// Get the feature of a edge
|
||||||
|
// @param std::shared_ptr<Tensor> edges - List of edges
|
||||||
|
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
|
||||||
|
// does not exist.
|
||||||
|
// @param Tensor *out - Returned features
|
||||||
|
// @return Status - The error code return
|
||||||
|
virtual Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types,
|
||||||
|
TensorRow *out) = 0;
|
||||||
|
|
||||||
|
// Return meta information to python layer
|
||||||
|
virtual Status GraphInfo(py::dict *out) = 0;
|
||||||
|
|
||||||
|
virtual Status Init() = 0;
|
||||||
|
|
||||||
|
virtual Status Stop() = 0;
|
||||||
|
};
|
||||||
|
} // namespace gnn
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_H_
|
|
@ -0,0 +1,589 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "minddata/dataset/engine/gnn/graph_data_client.h"
|
||||||
|
|
||||||
|
#include <unistd.h>
|
||||||
|
#include <functional>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
#include "grpcpp/grpcpp.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "minddata/dataset/core/data_type.h"
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
#include "minddata/dataset/engine/gnn/tensor_proto.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
namespace gnn {
|
||||||
|
|
||||||
|
GraphDataClient::GraphDataClient(const std::string &dataset_file, const std::string &hostname, int32_t port)
|
||||||
|
: dataset_file_(dataset_file),
|
||||||
|
host_(hostname),
|
||||||
|
port_(port),
|
||||||
|
pid_(0),
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
shared_memory_key_(-1),
|
||||||
|
shared_memory_size_(0),
|
||||||
|
graph_feature_parser_(nullptr),
|
||||||
|
graph_shared_memory_(nullptr),
|
||||||
|
#endif
|
||||||
|
registered_(false) {
|
||||||
|
}
|
||||||
|
|
||||||
|
GraphDataClient::~GraphDataClient() { (void)Stop(); }
|
||||||
|
|
||||||
|
Status GraphDataClient::Init() {
|
||||||
|
#if defined(_WIN32) || defined(_WIN64)
|
||||||
|
RETURN_STATUS_UNEXPECTED("Graph data client is not supported in Windows OS");
|
||||||
|
#else
|
||||||
|
if (!registered_) {
|
||||||
|
std::string server_address;
|
||||||
|
server_address = host_ + ":" + std::to_string(port_);
|
||||||
|
MS_LOG(INFO) << "Graph data client starting. address:" << server_address;
|
||||||
|
pid_ = getpid();
|
||||||
|
grpc::ChannelArguments args;
|
||||||
|
args.SetMaxReceiveMessageSize(-1);
|
||||||
|
std::shared_ptr<grpc::Channel> channel =
|
||||||
|
grpc::CreateCustomChannel(server_address, grpc::InsecureChannelCredentials(), args);
|
||||||
|
stub_ = GnnGraphData::NewStub(channel);
|
||||||
|
Status status = RegisterToServer();
|
||||||
|
while (status.ToString().find("Initializing") != std::string::npos) {
|
||||||
|
MS_LOG(INFO) << "Graph data server is initializing, please wait.";
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(2000));
|
||||||
|
status = RegisterToServer();
|
||||||
|
}
|
||||||
|
RETURN_IF_NOT_OK(status);
|
||||||
|
MS_LOG(INFO) << "Graph data client successfully registered with server " << server_address;
|
||||||
|
}
|
||||||
|
RETURN_IF_NOT_OK(InitFeatureParser());
|
||||||
|
return Status::OK();
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataClient::Stop() {
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
if (registered_) {
|
||||||
|
UnRegisterToServer();
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataClient::GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) {
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
GnnGraphDataRequestPb request;
|
||||||
|
GnnGraphDataResponsePb response;
|
||||||
|
request.set_op_name(GET_ALL_NODES);
|
||||||
|
request.add_type(static_cast<google::protobuf::int32>(node_type));
|
||||||
|
RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out));
|
||||||
|
#endif
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataClient::GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) {
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
GnnGraphDataRequestPb request;
|
||||||
|
GnnGraphDataResponsePb response;
|
||||||
|
request.set_op_name(GET_ALL_EDGES);
|
||||||
|
request.add_type(static_cast<google::protobuf::int32>(edge_type));
|
||||||
|
RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out));
|
||||||
|
#endif
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataClient::GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) {
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
GnnGraphDataRequestPb request;
|
||||||
|
GnnGraphDataResponsePb response;
|
||||||
|
request.set_op_name(GET_NODES_FROM_EDGES);
|
||||||
|
for (const auto &edge_id : edge_list) {
|
||||||
|
request.add_id(static_cast<google::protobuf::int32>(edge_id));
|
||||||
|
}
|
||||||
|
RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out));
|
||||||
|
#endif
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataClient::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
|
||||||
|
std::shared_ptr<Tensor> *out) {
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
GnnGraphDataRequestPb request;
|
||||||
|
GnnGraphDataResponsePb response;
|
||||||
|
request.set_op_name(GET_ALL_NEIGHBORS);
|
||||||
|
for (const auto &node_id : node_list) {
|
||||||
|
request.add_id(static_cast<google::protobuf::int32>(node_id));
|
||||||
|
}
|
||||||
|
request.add_type(static_cast<google::protobuf::int32>(neighbor_type));
|
||||||
|
RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out));
|
||||||
|
#endif
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataClient::GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
|
||||||
|
const std::vector<NodeIdType> &neighbor_nums,
|
||||||
|
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) {
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
GnnGraphDataRequestPb request;
|
||||||
|
GnnGraphDataResponsePb response;
|
||||||
|
request.set_op_name(GET_SAMPLED_NEIGHBORS);
|
||||||
|
for (const auto &node_id : node_list) {
|
||||||
|
request.add_id(static_cast<google::protobuf::int32>(node_id));
|
||||||
|
}
|
||||||
|
for (const auto &num : neighbor_nums) {
|
||||||
|
request.add_number(static_cast<google::protobuf::int32>(num));
|
||||||
|
}
|
||||||
|
for (const auto &type : neighbor_types) {
|
||||||
|
request.add_type(static_cast<google::protobuf::int32>(type));
|
||||||
|
}
|
||||||
|
RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out));
|
||||||
|
#endif
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataClient::GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
|
||||||
|
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) {
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
GnnGraphDataRequestPb request;
|
||||||
|
GnnGraphDataResponsePb response;
|
||||||
|
request.set_op_name(GET_NEG_SAMPLED_NEIGHBORS);
|
||||||
|
for (const auto &node_id : node_list) {
|
||||||
|
request.add_id(static_cast<google::protobuf::int32>(node_id));
|
||||||
|
}
|
||||||
|
request.add_number(static_cast<google::protobuf::int32>(samples_num));
|
||||||
|
request.add_type(static_cast<google::protobuf::int32>(neg_neighbor_type));
|
||||||
|
RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out));
|
||||||
|
#endif
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataClient::GraphDataClient::RandomWalk(const std::vector<NodeIdType> &node_list,
|
||||||
|
const std::vector<NodeType> &meta_path, float step_home_param,
|
||||||
|
float step_away_param, NodeIdType default_node,
|
||||||
|
std::shared_ptr<Tensor> *out) {
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
GnnGraphDataRequestPb request;
|
||||||
|
GnnGraphDataResponsePb response;
|
||||||
|
request.set_op_name(RANDOM_WALK);
|
||||||
|
for (const auto &node_id : node_list) {
|
||||||
|
request.add_id(static_cast<google::protobuf::int32>(node_id));
|
||||||
|
}
|
||||||
|
for (const auto &type : meta_path) {
|
||||||
|
request.add_type(static_cast<google::protobuf::int32>(type));
|
||||||
|
}
|
||||||
|
auto walk_param = request.mutable_random_walk();
|
||||||
|
walk_param->set_p(step_home_param);
|
||||||
|
walk_param->set_q(step_away_param);
|
||||||
|
walk_param->set_default_id(static_cast<google::protobuf::int32>(default_node));
|
||||||
|
RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out));
|
||||||
|
#endif
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataClient::GetNodeFeature(const std::shared_ptr<Tensor> &nodes,
|
||||||
|
const std::vector<FeatureType> &feature_types, TensorRow *out) {
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
if (!nodes || nodes->Size() == 0) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("Input nodes is empty");
|
||||||
|
}
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Input feature_types is empty");
|
||||||
|
|
||||||
|
GnnGraphDataRequestPb request;
|
||||||
|
GnnGraphDataResponsePb response;
|
||||||
|
request.set_op_name(GET_NODE_FEATURE);
|
||||||
|
for (const auto &type : feature_types) {
|
||||||
|
request.add_type(static_cast<google::protobuf::int32>(type));
|
||||||
|
}
|
||||||
|
RETURN_IF_NOT_OK(TensorToPb(nodes, request.mutable_id_tensor()));
|
||||||
|
RETURN_IF_NOT_OK(GetGraphData(request, &response));
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(feature_types.size() == response.result_data().size(),
|
||||||
|
"The number of feature types returned by the server is wrong");
|
||||||
|
if (response.result_data().size() > 0) {
|
||||||
|
size_t i = 0;
|
||||||
|
for (const auto &result : response.result_data()) {
|
||||||
|
std::shared_ptr<Tensor> tensor;
|
||||||
|
RETURN_IF_NOT_OK(PbToTensor(&result, &tensor));
|
||||||
|
std::shared_ptr<Tensor> fea_tensor;
|
||||||
|
RETURN_IF_NOT_OK(ParseNodeFeatureFromMemory(nodes, feature_types[i], tensor, &fea_tensor));
|
||||||
|
out->emplace_back(std::move(fea_tensor));
|
||||||
|
++i;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
RETURN_STATUS_UNEXPECTED("RPC failed: The number of returned tensor is abnormal");
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataClient::GetEdgeFeature(const std::shared_ptr<Tensor> &edges,
|
||||||
|
const std::vector<FeatureType> &feature_types, TensorRow *out) {
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
if (!edges || edges->Size() == 0) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("Input edges is empty");
|
||||||
|
}
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(!feature_types.empty(), "Input feature_types is empty");
|
||||||
|
|
||||||
|
GnnGraphDataRequestPb request;
|
||||||
|
GnnGraphDataResponsePb response;
|
||||||
|
request.set_op_name(GET_EDGE_FEATURE);
|
||||||
|
for (const auto &type : feature_types) {
|
||||||
|
request.add_type(static_cast<google::protobuf::int32>(type));
|
||||||
|
}
|
||||||
|
RETURN_IF_NOT_OK(TensorToPb(edges, request.mutable_id_tensor()));
|
||||||
|
RETURN_IF_NOT_OK(GetGraphData(request, &response));
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(feature_types.size() == response.result_data().size(),
|
||||||
|
"The number of feature types returned by the server is wrong");
|
||||||
|
if (response.result_data().size() > 0) {
|
||||||
|
size_t i = 0;
|
||||||
|
for (const auto &result : response.result_data()) {
|
||||||
|
std::shared_ptr<Tensor> tensor;
|
||||||
|
RETURN_IF_NOT_OK(PbToTensor(&result, &tensor));
|
||||||
|
std::shared_ptr<Tensor> fea_tensor;
|
||||||
|
RETURN_IF_NOT_OK(ParseEdgeFeatureFromMemory(edges, feature_types[i], tensor, &fea_tensor));
|
||||||
|
out->emplace_back(std::move(fea_tensor));
|
||||||
|
++i;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
RETURN_STATUS_UNEXPECTED("RPC failed: The number of returned tensor is abnormal");
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataClient::GraphInfo(py::dict *out) {
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
RETURN_IF_NOT_OK(CheckPid());
|
||||||
|
void *tag;
|
||||||
|
bool ok;
|
||||||
|
grpc::Status status;
|
||||||
|
grpc::ClientContext ctx;
|
||||||
|
grpc::CompletionQueue cq;
|
||||||
|
GnnMetaInfoRequestPb request;
|
||||||
|
GnnMetaInfoResponsePb response;
|
||||||
|
// One minute timeout
|
||||||
|
auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60);
|
||||||
|
ctx.set_deadline(deadline);
|
||||||
|
std::unique_ptr<grpc::ClientAsyncResponseReader<GnnMetaInfoResponsePb>> rpc(
|
||||||
|
stub_->PrepareAsyncGetMetaInfo(&ctx, request, &cq));
|
||||||
|
rpc->StartCall();
|
||||||
|
rpc->Finish(&response, &status, &response);
|
||||||
|
|
||||||
|
{
|
||||||
|
py::gil_scoped_release gil_release;
|
||||||
|
auto success = cq.Next(&tag, &ok);
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(success, "Expect successful");
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(tag == &response, "Expect the same tag");
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(ok, "Expect successful");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (status.ok()) {
|
||||||
|
if (response.error_msg() != "Success") {
|
||||||
|
RETURN_STATUS_UNEXPECTED(response.error_msg());
|
||||||
|
} else {
|
||||||
|
MetaInfo meta_info;
|
||||||
|
for (const auto &node : response.node_info()) {
|
||||||
|
meta_info.node_type.emplace_back(static_cast<NodeType>(node.type()));
|
||||||
|
meta_info.node_num[static_cast<NodeType>(node.type())] = static_cast<NodeIdType>(node.num());
|
||||||
|
}
|
||||||
|
for (const auto &edge : response.edge_info()) {
|
||||||
|
meta_info.edge_type.emplace_back(static_cast<EdgeType>(edge.type()));
|
||||||
|
meta_info.edge_num[static_cast<EdgeType>(edge.type())] = static_cast<EdgeIdType>(edge.num());
|
||||||
|
}
|
||||||
|
for (const auto &feature_type : response.node_feature_type()) {
|
||||||
|
meta_info.node_feature_type.emplace_back(static_cast<FeatureType>(feature_type));
|
||||||
|
}
|
||||||
|
for (const auto &feature_type : response.edge_feature_type()) {
|
||||||
|
meta_info.edge_feature_type.emplace_back(static_cast<FeatureType>(feature_type));
|
||||||
|
}
|
||||||
|
(*out)["node_type"] = py::cast(meta_info.node_type);
|
||||||
|
(*out)["edge_type"] = py::cast(meta_info.edge_type);
|
||||||
|
(*out)["node_num"] = py::cast(meta_info.node_num);
|
||||||
|
(*out)["edge_num"] = py::cast(meta_info.edge_num);
|
||||||
|
(*out)["node_feature_type"] = py::cast(meta_info.node_feature_type);
|
||||||
|
(*out)["edge_feature_type"] = py::cast(meta_info.edge_feature_type);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto error_code = status.error_code();
|
||||||
|
RETURN_STATUS_UNEXPECTED(status.error_message() + ". GRPC Code " + std::to_string(error_code));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
Status GraphDataClient::GetGraphData(const GnnGraphDataRequestPb &request, GnnGraphDataResponsePb *response) {
|
||||||
|
RETURN_IF_NOT_OK(CheckPid());
|
||||||
|
void *tag;
|
||||||
|
bool ok;
|
||||||
|
grpc::Status status;
|
||||||
|
grpc::ClientContext ctx;
|
||||||
|
grpc::CompletionQueue cq;
|
||||||
|
// One minute timeout
|
||||||
|
auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60);
|
||||||
|
ctx.set_deadline(deadline);
|
||||||
|
std::unique_ptr<grpc::ClientAsyncResponseReader<GnnGraphDataResponsePb>> rpc(
|
||||||
|
stub_->PrepareAsyncGetGraphData(&ctx, request, &cq));
|
||||||
|
rpc->StartCall();
|
||||||
|
rpc->Finish(response, &status, response);
|
||||||
|
|
||||||
|
{
|
||||||
|
py::gil_scoped_release gil_release;
|
||||||
|
auto success = cq.Next(&tag, &ok);
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(success, "Expect successful");
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(tag == response, "Expect the same tag");
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(ok, "Expect successful");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (status.ok()) {
|
||||||
|
if (response->error_msg() != "Success") {
|
||||||
|
RETURN_STATUS_UNEXPECTED(response->error_msg());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto error_code = status.error_code();
|
||||||
|
RETURN_STATUS_UNEXPECTED(status.error_message() + ". GRPC Code " + std::to_string(error_code));
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataClient::GetGraphDataTensor(const GnnGraphDataRequestPb &request, GnnGraphDataResponsePb *response,
|
||||||
|
std::shared_ptr<Tensor> *out) {
|
||||||
|
RETURN_IF_NOT_OK(GetGraphData(request, response));
|
||||||
|
if (1 == response->result_data().size()) {
|
||||||
|
const TensorPb &result = response->result_data()[0];
|
||||||
|
std::shared_ptr<Tensor> tensor;
|
||||||
|
RETURN_IF_NOT_OK(PbToTensor(&result, &tensor));
|
||||||
|
*out = std::move(tensor);
|
||||||
|
} else {
|
||||||
|
RETURN_STATUS_UNEXPECTED("RPC failed: The number of returned tensor is abnormal");
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataClient::ParseNodeFeatureFromMemory(const std::shared_ptr<Tensor> &nodes, FeatureType feature_type,
|
||||||
|
const std::shared_ptr<Tensor> &memory_tensor,
|
||||||
|
std::shared_ptr<Tensor> *out) {
|
||||||
|
std::shared_ptr<Tensor> default_feature;
|
||||||
|
// If no feature can be obtained, fill in the default value
|
||||||
|
RETURN_IF_NOT_OK(GetNodeDefaultFeature(feature_type, &default_feature));
|
||||||
|
TensorShape shape(default_feature->shape());
|
||||||
|
auto shape_vec = nodes->shape().AsVector();
|
||||||
|
dsize_t size = std::accumulate(shape_vec.begin(), shape_vec.end(), 1, std::multiplies<dsize_t>());
|
||||||
|
shape = shape.PrependDim(size);
|
||||||
|
std::shared_ptr<Tensor> fea_tensor;
|
||||||
|
RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, default_feature->type(), &fea_tensor));
|
||||||
|
|
||||||
|
dsize_t index = 0;
|
||||||
|
auto fea_addr_itr = memory_tensor->begin<int64_t>();
|
||||||
|
for (auto node_itr = nodes->begin<NodeIdType>(); node_itr != nodes->end<NodeIdType>(); ++node_itr) {
|
||||||
|
int64_t offset = *fea_addr_itr;
|
||||||
|
fea_addr_itr++;
|
||||||
|
int64_t len = *fea_addr_itr;
|
||||||
|
fea_addr_itr++;
|
||||||
|
if (*node_itr == kDefaultNodeId || offset < 0 || len <= 0) {
|
||||||
|
RETURN_IF_NOT_OK(fea_tensor->InsertTensor({index}, default_feature));
|
||||||
|
} else {
|
||||||
|
uchar *start_addr_of_index = nullptr;
|
||||||
|
TensorShape remaining({-1});
|
||||||
|
RETURN_IF_NOT_OK(fea_tensor->StartAddrOfIndex({index}, &start_addr_of_index, &remaining));
|
||||||
|
RETURN_IF_NOT_OK(graph_shared_memory_->GetData(start_addr_of_index, len, offset, len));
|
||||||
|
}
|
||||||
|
index++;
|
||||||
|
}
|
||||||
|
|
||||||
|
TensorShape reshape(nodes->shape());
|
||||||
|
for (auto s : default_feature->shape().AsVector()) {
|
||||||
|
reshape = reshape.AppendDim(s);
|
||||||
|
}
|
||||||
|
RETURN_IF_NOT_OK(fea_tensor->Reshape(reshape));
|
||||||
|
fea_tensor->Squeeze();
|
||||||
|
|
||||||
|
*out = std::move(fea_tensor);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataClient::ParseEdgeFeatureFromMemory(const std::shared_ptr<Tensor> &edges, FeatureType feature_type,
|
||||||
|
const std::shared_ptr<Tensor> &memory_tensor,
|
||||||
|
std::shared_ptr<Tensor> *out) {
|
||||||
|
std::shared_ptr<Tensor> default_feature;
|
||||||
|
// If no feature can be obtained, fill in the default value
|
||||||
|
RETURN_IF_NOT_OK(GetEdgeDefaultFeature(feature_type, &default_feature));
|
||||||
|
TensorShape shape(default_feature->shape());
|
||||||
|
auto shape_vec = edges->shape().AsVector();
|
||||||
|
dsize_t size = std::accumulate(shape_vec.begin(), shape_vec.end(), 1, std::multiplies<dsize_t>());
|
||||||
|
shape = shape.PrependDim(size);
|
||||||
|
std::shared_ptr<Tensor> fea_tensor;
|
||||||
|
RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, default_feature->type(), &fea_tensor));
|
||||||
|
|
||||||
|
dsize_t index = 0;
|
||||||
|
auto fea_addr_itr = memory_tensor->begin<int64_t>();
|
||||||
|
for (auto edge_itr = edges->begin<EdgeIdType>(); edge_itr != edges->end<EdgeIdType>(); ++edge_itr) {
|
||||||
|
int64_t offset = *fea_addr_itr;
|
||||||
|
fea_addr_itr++;
|
||||||
|
int64_t len = *fea_addr_itr;
|
||||||
|
fea_addr_itr++;
|
||||||
|
if (offset < 0 || len <= 0) {
|
||||||
|
RETURN_IF_NOT_OK(fea_tensor->InsertTensor({index}, default_feature));
|
||||||
|
} else {
|
||||||
|
uchar *start_addr_of_index = nullptr;
|
||||||
|
TensorShape remaining({-1});
|
||||||
|
RETURN_IF_NOT_OK(fea_tensor->StartAddrOfIndex({index}, &start_addr_of_index, &remaining));
|
||||||
|
RETURN_IF_NOT_OK(graph_shared_memory_->GetData(start_addr_of_index, len, offset, len));
|
||||||
|
}
|
||||||
|
index++;
|
||||||
|
}
|
||||||
|
|
||||||
|
TensorShape reshape(edges->shape());
|
||||||
|
for (auto s : default_feature->shape().AsVector()) {
|
||||||
|
reshape = reshape.AppendDim(s);
|
||||||
|
}
|
||||||
|
RETURN_IF_NOT_OK(fea_tensor->Reshape(reshape));
|
||||||
|
fea_tensor->Squeeze();
|
||||||
|
|
||||||
|
*out = std::move(fea_tensor);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataClient::GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Tensor> *out_feature) {
|
||||||
|
auto itr = default_node_feature_map_.find(feature_type);
|
||||||
|
if (itr == default_node_feature_map_.end()) {
|
||||||
|
std::string err_msg = "Invalid feature type:" + std::to_string(feature_type);
|
||||||
|
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||||
|
} else {
|
||||||
|
*out_feature = itr->second;
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataClient::GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Tensor> *out_feature) {
|
||||||
|
auto itr = default_edge_feature_map_.find(feature_type);
|
||||||
|
if (itr == default_edge_feature_map_.end()) {
|
||||||
|
std::string err_msg = "Invalid feature type:" + std::to_string(feature_type);
|
||||||
|
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||||
|
} else {
|
||||||
|
*out_feature = itr->second;
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataClient::RegisterToServer() {
|
||||||
|
RETURN_IF_NOT_OK(CheckPid());
|
||||||
|
void *tag;
|
||||||
|
bool ok;
|
||||||
|
grpc::Status status;
|
||||||
|
grpc::ClientContext ctx;
|
||||||
|
grpc::CompletionQueue cq;
|
||||||
|
GnnClientRegisterRequestPb request;
|
||||||
|
GnnClientRegisterResponsePb response;
|
||||||
|
request.set_pid(static_cast<google::protobuf::int32>(pid_));
|
||||||
|
// One minute timeout
|
||||||
|
auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60);
|
||||||
|
ctx.set_deadline(deadline);
|
||||||
|
std::unique_ptr<grpc::ClientAsyncResponseReader<GnnClientRegisterResponsePb>> rpc(
|
||||||
|
stub_->PrepareAsyncClientRegister(&ctx, request, &cq));
|
||||||
|
rpc->StartCall();
|
||||||
|
rpc->Finish(&response, &status, &response);
|
||||||
|
|
||||||
|
{
|
||||||
|
py::gil_scoped_release gil_release;
|
||||||
|
auto success = cq.Next(&tag, &ok);
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(success, "Expect successful");
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(tag == &response, "Expect the same tag");
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(ok, "Expect successful");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (status.ok()) {
|
||||||
|
if (response.error_msg() == "Success") {
|
||||||
|
registered_ = true;
|
||||||
|
data_schema_ = mindrecord::json::parse(response.data_schema());
|
||||||
|
shared_memory_key_ = static_cast<key_t>(response.shared_memory_key());
|
||||||
|
shared_memory_size_ = response.shared_memory_size();
|
||||||
|
MS_LOG(INFO) << "Register success, recv data_schema:" << response.data_schema();
|
||||||
|
for (auto feature_info : response.default_node_feature()) {
|
||||||
|
std::shared_ptr<Tensor> tensor;
|
||||||
|
RETURN_IF_NOT_OK(PbToTensor(&feature_info.feature(), &tensor));
|
||||||
|
default_node_feature_map_[feature_info.type()] = tensor;
|
||||||
|
}
|
||||||
|
for (auto feature_info : response.default_edge_feature()) {
|
||||||
|
std::shared_ptr<Tensor> tensor;
|
||||||
|
RETURN_IF_NOT_OK(PbToTensor(&feature_info.feature(), &tensor));
|
||||||
|
default_edge_feature_map_[feature_info.type()] = tensor;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
RETURN_STATUS_UNEXPECTED(response.error_msg());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto error_code = status.error_code();
|
||||||
|
RETURN_STATUS_UNEXPECTED(status.error_message() + ". GRPC Code " + std::to_string(error_code));
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataClient::UnRegisterToServer() {
|
||||||
|
RETURN_IF_NOT_OK(CheckPid());
|
||||||
|
MS_LOG(INFO) << "Graph data client send unregistered to server ";
|
||||||
|
void *tag;
|
||||||
|
bool ok;
|
||||||
|
grpc::Status status;
|
||||||
|
grpc::ClientContext ctx;
|
||||||
|
grpc::CompletionQueue cq;
|
||||||
|
GnnClientUnRegisterRequestPb request;
|
||||||
|
GnnClientUnRegisterResponsePb response;
|
||||||
|
request.set_pid(static_cast<google::protobuf::int32>(pid_));
|
||||||
|
// One minute timeout
|
||||||
|
auto deadline = std::chrono::system_clock::now() + std::chrono::seconds(60);
|
||||||
|
ctx.set_deadline(deadline);
|
||||||
|
std::unique_ptr<grpc::ClientAsyncResponseReader<GnnClientUnRegisterResponsePb>> rpc(
|
||||||
|
stub_->PrepareAsyncClientUnRegister(&ctx, request, &cq));
|
||||||
|
rpc->StartCall();
|
||||||
|
rpc->Finish(&response, &status, &response);
|
||||||
|
{
|
||||||
|
py::gil_scoped_release gil_release;
|
||||||
|
auto success = cq.Next(&tag, &ok);
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(success, "Expect successful");
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(tag == &response, "Expect the same tag");
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(ok, "Expect successful");
|
||||||
|
}
|
||||||
|
if (status.ok()) {
|
||||||
|
if (response.error_msg() == "Success") {
|
||||||
|
MS_LOG(INFO) << "Unregister success.";
|
||||||
|
registered_ = false;
|
||||||
|
} else {
|
||||||
|
RETURN_STATUS_UNEXPECTED(response.error_msg());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
auto error_code = status.error_code();
|
||||||
|
RETURN_STATUS_UNEXPECTED(status.error_message() + ". GRPC Code " + std::to_string(error_code));
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataClient::InitFeatureParser() {
|
||||||
|
// get shared memory
|
||||||
|
graph_shared_memory_ = std::make_unique<GraphSharedMemory>(shared_memory_size_, shared_memory_key_);
|
||||||
|
RETURN_IF_NOT_OK(graph_shared_memory_->GetSharedMemory());
|
||||||
|
// build feature parser
|
||||||
|
graph_feature_parser_ = std::make_unique<GraphFeatureParser>(ShardColumn(data_schema_));
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
} // namespace gnn
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,185 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_CLIENT_H_
|
||||||
|
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_CLIENT_H_
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <map>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <vector>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
#include "proto/gnn_graph_data.grpc.pb.h"
|
||||||
|
#include "proto/gnn_graph_data.pb.h"
|
||||||
|
#endif
|
||||||
|
#include "minddata/dataset/engine/gnn/graph_data.h"
|
||||||
|
#include "minddata/dataset/engine/gnn/graph_feature_parser.h"
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
#include "minddata/dataset/engine/gnn/graph_shared_memory.h"
|
||||||
|
#endif
|
||||||
|
#include "minddata/mindrecord/include/common/shard_utils.h"
|
||||||
|
#include "minddata/mindrecord/include/shard_column.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
namespace gnn {
|
||||||
|
|
||||||
|
class GraphDataClient : public GraphData {
|
||||||
|
public:
|
||||||
|
// Constructor
|
||||||
|
// @param std::string dataset_file -
|
||||||
|
// @param int32_t num_workers - number of parallel threads
|
||||||
|
GraphDataClient(const std::string &dataset_file, const std::string &hostname, int32_t port);
|
||||||
|
|
||||||
|
~GraphDataClient();
|
||||||
|
|
||||||
|
Status Init() override;
|
||||||
|
|
||||||
|
Status Stop() override;
|
||||||
|
|
||||||
|
// Get all nodes from the graph.
|
||||||
|
// @param NodeType node_type - type of node
|
||||||
|
// @param std::shared_ptr<Tensor> *out - Returned nodes id
|
||||||
|
// @return Status - The error code return
|
||||||
|
Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) override;
|
||||||
|
|
||||||
|
// Get all edges from the graph.
|
||||||
|
// @param NodeType edge_type - type of edge
|
||||||
|
// @param std::shared_ptr<Tensor> *out - Returned edge ids
|
||||||
|
// @return Status - The error code return
|
||||||
|
Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) override;
|
||||||
|
|
||||||
|
// Get the node id from the edge.
|
||||||
|
// @param std::vector<EdgeIdType> edge_list - List of edges
|
||||||
|
// @param std::shared_ptr<Tensor> *out - Returned node ids
|
||||||
|
// @return Status - The error code return
|
||||||
|
Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override;
|
||||||
|
|
||||||
|
// All neighbors of the acquisition node.
|
||||||
|
// @param std::vector<NodeType> node_list - List of nodes
|
||||||
|
// @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported
|
||||||
|
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is
|
||||||
|
// different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors
|
||||||
|
// is not enough, fill in tensor as -1.
|
||||||
|
// @return Status - The error code return
|
||||||
|
Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
|
||||||
|
std::shared_ptr<Tensor> *out) override;
|
||||||
|
|
||||||
|
// Get sampled neighbors.
|
||||||
|
// @param std::vector<NodeType> node_list - List of nodes
|
||||||
|
// @param std::vector<NodeIdType> neighbor_nums - Number of neighbors sampled per hop
|
||||||
|
// @param std::vector<NodeType> neighbor_types - Neighbor type sampled per hop
|
||||||
|
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
|
||||||
|
// @return Status - The error code return
|
||||||
|
Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
|
||||||
|
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override;
|
||||||
|
|
||||||
|
// Get negative sampled neighbors.
|
||||||
|
// @param std::vector<NodeType> node_list - List of nodes
|
||||||
|
// @param NodeIdType samples_num - Number of neighbors sampled
|
||||||
|
// @param NodeType neg_neighbor_type - The type of negative neighbor.
|
||||||
|
// @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id.
|
||||||
|
// @return Status - The error code return
|
||||||
|
Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
|
||||||
|
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) override;
|
||||||
|
|
||||||
|
// Node2vec random walk.
|
||||||
|
// @param std::vector<NodeIdType> node_list - List of nodes
|
||||||
|
// @param std::vector<NodeType> meta_path - node type of each step
|
||||||
|
// @param float step_home_param - return hyper parameter in node2vec algorithm
|
||||||
|
// @param float step_away_param - inout hyper parameter in node2vec algorithm
|
||||||
|
// @param NodeIdType default_node - default node id
|
||||||
|
// @param std::shared_ptr<Tensor> *out - Returned nodes id in walk path
|
||||||
|
// @return Status - The error code return
|
||||||
|
Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
|
||||||
|
float step_home_param, float step_away_param, NodeIdType default_node,
|
||||||
|
std::shared_ptr<Tensor> *out) override;
|
||||||
|
|
||||||
|
// Get the feature of a node
|
||||||
|
// @param std::shared_ptr<Tensor> nodes - List of nodes
|
||||||
|
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
|
||||||
|
// does not exist.
|
||||||
|
// @param TensorRow *out - Returned features
|
||||||
|
// @return Status - The error code return
|
||||||
|
Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
|
||||||
|
TensorRow *out) override;
|
||||||
|
|
||||||
|
// Get the feature of a edge
|
||||||
|
// @param std::shared_ptr<Tensor> edges - List of edges
|
||||||
|
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
|
||||||
|
// does not exist.
|
||||||
|
// @param Tensor *out - Returned features
|
||||||
|
// @return Status - The error code return
|
||||||
|
Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types,
|
||||||
|
TensorRow *out) override;
|
||||||
|
|
||||||
|
// Return meta information to python layer
|
||||||
|
Status GraphInfo(py::dict *out) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
Status ParseNodeFeatureFromMemory(const std::shared_ptr<Tensor> &nodes, FeatureType feature_type,
|
||||||
|
const std::shared_ptr<Tensor> &memory_tensor, std::shared_ptr<Tensor> *out);
|
||||||
|
|
||||||
|
Status ParseEdgeFeatureFromMemory(const std::shared_ptr<Tensor> &edges, FeatureType feature_type,
|
||||||
|
const std::shared_ptr<Tensor> &memory_tensor, std::shared_ptr<Tensor> *out);
|
||||||
|
|
||||||
|
Status GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Tensor> *out_feature);
|
||||||
|
|
||||||
|
Status GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Tensor> *out_feature);
|
||||||
|
|
||||||
|
Status GetGraphData(const GnnGraphDataRequestPb &request, GnnGraphDataResponsePb *response);
|
||||||
|
|
||||||
|
Status GetGraphDataTensor(const GnnGraphDataRequestPb &request, GnnGraphDataResponsePb *response,
|
||||||
|
std::shared_ptr<Tensor> *out);
|
||||||
|
|
||||||
|
Status RegisterToServer();
|
||||||
|
|
||||||
|
Status UnRegisterToServer();
|
||||||
|
|
||||||
|
Status InitFeatureParser();
|
||||||
|
|
||||||
|
Status CheckPid() {
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(pid_ == getpid(),
|
||||||
|
"Multi-process mode is not supported, please change to use multi-thread");
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
std::string dataset_file_;
|
||||||
|
std::string host_;
|
||||||
|
int32_t port_;
|
||||||
|
int32_t pid_;
|
||||||
|
mindrecord::json data_schema_;
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
std::unique_ptr<GnnGraphData::Stub> stub_;
|
||||||
|
key_t shared_memory_key_;
|
||||||
|
int64_t shared_memory_size_;
|
||||||
|
std::unique_ptr<GraphFeatureParser> graph_feature_parser_;
|
||||||
|
std::unique_ptr<GraphSharedMemory> graph_shared_memory_;
|
||||||
|
std::unordered_map<FeatureType, std::shared_ptr<Tensor>> default_node_feature_map_;
|
||||||
|
std::unordered_map<FeatureType, std::shared_ptr<Tensor>> default_edge_feature_map_;
|
||||||
|
#endif
|
||||||
|
bool registered_;
|
||||||
|
};
|
||||||
|
} // namespace gnn
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_CLIENT_H_
|
|
@ -13,7 +13,7 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
#include "minddata/dataset/engine/gnn/graph.h"
|
#include "minddata/dataset/engine/gnn/graph_data_impl.h"
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
@ -22,19 +22,25 @@
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "minddata/dataset/core/tensor_shape.h"
|
#include "minddata/dataset/core/tensor_shape.h"
|
||||||
|
#include "minddata/dataset/engine/gnn/graph_loader.h"
|
||||||
#include "minddata/dataset/util/random.h"
|
#include "minddata/dataset/util/random.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
namespace gnn {
|
namespace gnn {
|
||||||
|
|
||||||
Graph::Graph(std::string dataset_file, int32_t num_workers)
|
GraphDataImpl::GraphDataImpl(std::string dataset_file, int32_t num_workers, bool server_mode)
|
||||||
: dataset_file_(dataset_file), num_workers_(num_workers), rnd_(GetRandomDevice()), random_walk_(this) {
|
: dataset_file_(dataset_file),
|
||||||
|
num_workers_(num_workers),
|
||||||
|
rnd_(GetRandomDevice()),
|
||||||
|
random_walk_(this),
|
||||||
|
server_mode_(server_mode) {
|
||||||
rnd_.seed(GetSeed());
|
rnd_.seed(GetSeed());
|
||||||
MS_LOG(INFO) << "num_workers:" << num_workers;
|
MS_LOG(INFO) << "num_workers:" << num_workers;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Graph::GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) {
|
GraphDataImpl::~GraphDataImpl() {}
|
||||||
|
|
||||||
|
Status GraphDataImpl::GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) {
|
||||||
auto itr = node_type_map_.find(node_type);
|
auto itr = node_type_map_.find(node_type);
|
||||||
if (itr == node_type_map_.end()) {
|
if (itr == node_type_map_.end()) {
|
||||||
std::string err_msg = "Invalid node type:" + std::to_string(node_type);
|
std::string err_msg = "Invalid node type:" + std::to_string(node_type);
|
||||||
|
@ -46,8 +52,8 @@ Status Graph::GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) {
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Status Graph::CreateTensorByVector(const std::vector<std::vector<T>> &data, DataType type,
|
Status GraphDataImpl::CreateTensorByVector(const std::vector<std::vector<T>> &data, DataType type,
|
||||||
std::shared_ptr<Tensor> *out) {
|
std::shared_ptr<Tensor> *out) {
|
||||||
if (!type.IsCompatible<T>()) {
|
if (!type.IsCompatible<T>()) {
|
||||||
RETURN_STATUS_UNEXPECTED("Data type not compatible");
|
RETURN_STATUS_UNEXPECTED("Data type not compatible");
|
||||||
}
|
}
|
||||||
|
@ -72,7 +78,7 @@ Status Graph::CreateTensorByVector(const std::vector<std::vector<T>> &data, Data
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Status Graph::ComplementVector(std::vector<std::vector<T>> *data, size_t max_size, T default_value) {
|
Status GraphDataImpl::ComplementVector(std::vector<std::vector<T>> *data, size_t max_size, T default_value) {
|
||||||
if (!data || data->empty()) {
|
if (!data || data->empty()) {
|
||||||
RETURN_STATUS_UNEXPECTED("Input data is empty");
|
RETURN_STATUS_UNEXPECTED("Input data is empty");
|
||||||
}
|
}
|
||||||
|
@ -89,7 +95,7 @@ Status Graph::ComplementVector(std::vector<std::vector<T>> *data, size_t max_siz
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Graph::GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) {
|
Status GraphDataImpl::GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) {
|
||||||
auto itr = edge_type_map_.find(edge_type);
|
auto itr = edge_type_map_.find(edge_type);
|
||||||
if (itr == edge_type_map_.end()) {
|
if (itr == edge_type_map_.end()) {
|
||||||
std::string err_msg = "Invalid edge type:" + std::to_string(edge_type);
|
std::string err_msg = "Invalid edge type:" + std::to_string(edge_type);
|
||||||
|
@ -100,7 +106,7 @@ Status Graph::GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Graph::GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) {
|
Status GraphDataImpl::GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) {
|
||||||
if (edge_list.empty()) {
|
if (edge_list.empty()) {
|
||||||
RETURN_STATUS_UNEXPECTED("Input edge_list is empty");
|
RETURN_STATUS_UNEXPECTED("Input edge_list is empty");
|
||||||
}
|
}
|
||||||
|
@ -122,8 +128,8 @@ Status Graph::GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::s
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Graph::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
|
Status GraphDataImpl::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
|
||||||
std::shared_ptr<Tensor> *out) {
|
std::shared_ptr<Tensor> *out) {
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty.");
|
CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty.");
|
||||||
RETURN_IF_NOT_OK(CheckNeighborType(neighbor_type));
|
RETURN_IF_NOT_OK(CheckNeighborType(neighbor_type));
|
||||||
|
|
||||||
|
@ -143,7 +149,7 @@ Status Graph::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Graph::CheckSamplesNum(NodeIdType samples_num) {
|
Status GraphDataImpl::CheckSamplesNum(NodeIdType samples_num) {
|
||||||
NodeIdType all_nodes_number =
|
NodeIdType all_nodes_number =
|
||||||
std::accumulate(node_type_map_.begin(), node_type_map_.end(), 0,
|
std::accumulate(node_type_map_.begin(), node_type_map_.end(), 0,
|
||||||
[](NodeIdType t1, const auto &t2) -> NodeIdType { return t1 + t2.second.size(); });
|
[](NodeIdType t1, const auto &t2) -> NodeIdType { return t1 + t2.second.size(); });
|
||||||
|
@ -155,7 +161,7 @@ Status Graph::CheckSamplesNum(NodeIdType samples_num) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Graph::CheckNeighborType(NodeType neighbor_type) {
|
Status GraphDataImpl::CheckNeighborType(NodeType neighbor_type) {
|
||||||
if (node_type_map_.find(neighbor_type) == node_type_map_.end()) {
|
if (node_type_map_.find(neighbor_type) == node_type_map_.end()) {
|
||||||
std::string err_msg = "Invalid neighbor type:" + std::to_string(neighbor_type);
|
std::string err_msg = "Invalid neighbor type:" + std::to_string(neighbor_type);
|
||||||
RETURN_STATUS_UNEXPECTED(err_msg);
|
RETURN_STATUS_UNEXPECTED(err_msg);
|
||||||
|
@ -163,9 +169,9 @@ Status Graph::CheckNeighborType(NodeType neighbor_type) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Graph::GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
|
Status GraphDataImpl::GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
|
||||||
const std::vector<NodeIdType> &neighbor_nums,
|
const std::vector<NodeIdType> &neighbor_nums,
|
||||||
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) {
|
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) {
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty.");
|
CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty.");
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(neighbor_nums.size() == neighbor_types.size(),
|
CHECK_FAIL_RETURN_UNEXPECTED(neighbor_nums.size() == neighbor_types.size(),
|
||||||
"The sizes of neighbor_nums and neighbor_types are inconsistent.");
|
"The sizes of neighbor_nums and neighbor_types are inconsistent.");
|
||||||
|
@ -205,8 +211,9 @@ Status Graph::GetSampledNeighbors(const std::vector<NodeIdType> &node_list,
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Graph::NegativeSample(const std::vector<NodeIdType> &data, const std::unordered_set<NodeIdType> &exclude_data,
|
Status GraphDataImpl::NegativeSample(const std::vector<NodeIdType> &data,
|
||||||
int32_t samples_num, std::vector<NodeIdType> *out_samples) {
|
const std::unordered_set<NodeIdType> &exclude_data, int32_t samples_num,
|
||||||
|
std::vector<NodeIdType> *out_samples) {
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(!data.empty(), "Input data is empty.");
|
CHECK_FAIL_RETURN_UNEXPECTED(!data.empty(), "Input data is empty.");
|
||||||
std::vector<NodeIdType> shuffled_id(data.size());
|
std::vector<NodeIdType> shuffled_id(data.size());
|
||||||
std::iota(shuffled_id.begin(), shuffled_id.end(), 0);
|
std::iota(shuffled_id.begin(), shuffled_id.end(), 0);
|
||||||
|
@ -223,8 +230,8 @@ Status Graph::NegativeSample(const std::vector<NodeIdType> &data, const std::uno
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Graph::GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
|
Status GraphDataImpl::GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
|
||||||
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) {
|
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) {
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty.");
|
CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty.");
|
||||||
RETURN_IF_NOT_OK(CheckSamplesNum(samples_num));
|
RETURN_IF_NOT_OK(CheckSamplesNum(samples_num));
|
||||||
RETURN_IF_NOT_OK(CheckNeighborType(neg_neighbor_type));
|
RETURN_IF_NOT_OK(CheckNeighborType(neg_neighbor_type));
|
||||||
|
@ -260,9 +267,9 @@ Status Graph::GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, N
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Graph::RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
|
Status GraphDataImpl::RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
|
||||||
float step_home_param, float step_away_param, NodeIdType default_node,
|
float step_home_param, float step_away_param, NodeIdType default_node,
|
||||||
std::shared_ptr<Tensor> *out) {
|
std::shared_ptr<Tensor> *out) {
|
||||||
RETURN_IF_NOT_OK(random_walk_.Build(node_list, meta_path, step_home_param, step_away_param, default_node));
|
RETURN_IF_NOT_OK(random_walk_.Build(node_list, meta_path, step_home_param, step_away_param, default_node));
|
||||||
std::vector<std::vector<NodeIdType>> walks;
|
std::vector<std::vector<NodeIdType>> walks;
|
||||||
RETURN_IF_NOT_OK(random_walk_.SimulateWalk(&walks));
|
RETURN_IF_NOT_OK(random_walk_.SimulateWalk(&walks));
|
||||||
|
@ -270,7 +277,7 @@ Status Graph::RandomWalk(const std::vector<NodeIdType> &node_list, const std::ve
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Graph::GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) {
|
Status GraphDataImpl::GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) {
|
||||||
auto itr = default_node_feature_map_.find(feature_type);
|
auto itr = default_node_feature_map_.find(feature_type);
|
||||||
if (itr == default_node_feature_map_.end()) {
|
if (itr == default_node_feature_map_.end()) {
|
||||||
std::string err_msg = "Invalid feature type:" + std::to_string(feature_type);
|
std::string err_msg = "Invalid feature type:" + std::to_string(feature_type);
|
||||||
|
@ -281,7 +288,7 @@ Status Graph::GetNodeDefaultFeature(FeatureType feature_type, std::shared_ptr<Fe
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Graph::GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) {
|
Status GraphDataImpl::GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Feature> *out_feature) {
|
||||||
auto itr = default_edge_feature_map_.find(feature_type);
|
auto itr = default_edge_feature_map_.find(feature_type);
|
||||||
if (itr == default_edge_feature_map_.end()) {
|
if (itr == default_edge_feature_map_.end()) {
|
||||||
std::string err_msg = "Invalid feature type:" + std::to_string(feature_type);
|
std::string err_msg = "Invalid feature type:" + std::to_string(feature_type);
|
||||||
|
@ -292,8 +299,8 @@ Status Graph::GetEdgeDefaultFeature(FeatureType feature_type, std::shared_ptr<Fe
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Graph::GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
|
Status GraphDataImpl::GetNodeFeature(const std::shared_ptr<Tensor> &nodes,
|
||||||
TensorRow *out) {
|
const std::vector<FeatureType> &feature_types, TensorRow *out) {
|
||||||
if (!nodes || nodes->Size() == 0) {
|
if (!nodes || nodes->Size() == 0) {
|
||||||
RETURN_STATUS_UNEXPECTED("Input nodes is empty");
|
RETURN_STATUS_UNEXPECTED("Input nodes is empty");
|
||||||
}
|
}
|
||||||
|
@ -339,8 +346,49 @@ Status Graph::GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::ve
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Graph::GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types,
|
Status GraphDataImpl::GetNodeFeatureSharedMemory(const std::shared_ptr<Tensor> &nodes, FeatureType type,
|
||||||
TensorRow *out) {
|
std::shared_ptr<Tensor> *out) {
|
||||||
|
if (!nodes || nodes->Size() == 0) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("Input nodes is empty");
|
||||||
|
}
|
||||||
|
TensorShape shape = nodes->shape().AppendDim(2);
|
||||||
|
std::shared_ptr<Tensor> fea_tensor;
|
||||||
|
RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, DataType(DataType::DE_INT64), &fea_tensor));
|
||||||
|
|
||||||
|
auto out_fea_itr = fea_tensor->begin<int64_t>();
|
||||||
|
for (auto node_itr = nodes->begin<NodeIdType>(); node_itr != nodes->end<NodeIdType>(); ++node_itr) {
|
||||||
|
if (*node_itr == kDefaultNodeId) {
|
||||||
|
*out_fea_itr = -1;
|
||||||
|
++out_fea_itr;
|
||||||
|
*out_fea_itr = -1;
|
||||||
|
++out_fea_itr;
|
||||||
|
} else {
|
||||||
|
std::shared_ptr<Node> node;
|
||||||
|
RETURN_IF_NOT_OK(GetNodeByNodeId(*node_itr, &node));
|
||||||
|
std::shared_ptr<Feature> feature;
|
||||||
|
if (!node->GetFeatures(type, &feature).IsOk()) {
|
||||||
|
*out_fea_itr = -1;
|
||||||
|
++out_fea_itr;
|
||||||
|
*out_fea_itr = -1;
|
||||||
|
++out_fea_itr;
|
||||||
|
} else {
|
||||||
|
for (auto fea_itr = feature->Value()->begin<int64_t>(); fea_itr != feature->Value()->end<int64_t>();
|
||||||
|
++fea_itr) {
|
||||||
|
*out_fea_itr = *fea_itr;
|
||||||
|
++out_fea_itr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fea_tensor->Squeeze();
|
||||||
|
|
||||||
|
*out = std::move(fea_tensor);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataImpl::GetEdgeFeature(const std::shared_ptr<Tensor> &edges,
|
||||||
|
const std::vector<FeatureType> &feature_types, TensorRow *out) {
|
||||||
if (!edges || edges->Size() == 0) {
|
if (!edges || edges->Size() == 0) {
|
||||||
RETURN_STATUS_UNEXPECTED("Input edges is empty");
|
RETURN_STATUS_UNEXPECTED("Input edges is empty");
|
||||||
}
|
}
|
||||||
|
@ -382,12 +430,45 @@ Status Graph::GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::ve
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Graph::Init() {
|
Status GraphDataImpl::GetEdgeFeatureSharedMemory(const std::shared_ptr<Tensor> &edges, FeatureType type,
|
||||||
|
std::shared_ptr<Tensor> *out) {
|
||||||
|
if (!edges || edges->Size() == 0) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("Input edges is empty");
|
||||||
|
}
|
||||||
|
TensorShape shape = edges->shape().AppendDim(2);
|
||||||
|
std::shared_ptr<Tensor> fea_tensor;
|
||||||
|
RETURN_IF_NOT_OK(Tensor::CreateEmpty(shape, DataType(DataType::DE_INT64), &fea_tensor));
|
||||||
|
|
||||||
|
auto out_fea_itr = fea_tensor->begin<int64_t>();
|
||||||
|
for (auto edge_itr = edges->begin<EdgeIdType>(); edge_itr != edges->end<EdgeIdType>(); ++edge_itr) {
|
||||||
|
std::shared_ptr<Edge> edge;
|
||||||
|
RETURN_IF_NOT_OK(GetEdgeByEdgeId(*edge_itr, &edge));
|
||||||
|
std::shared_ptr<Feature> feature;
|
||||||
|
if (!edge->GetFeatures(type, &feature).IsOk()) {
|
||||||
|
*out_fea_itr = -1;
|
||||||
|
++out_fea_itr;
|
||||||
|
*out_fea_itr = -1;
|
||||||
|
++out_fea_itr;
|
||||||
|
} else {
|
||||||
|
for (auto fea_itr = feature->Value()->begin<int64_t>(); fea_itr != feature->Value()->end<int64_t>(); ++fea_itr) {
|
||||||
|
*out_fea_itr = *fea_itr;
|
||||||
|
++out_fea_itr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fea_tensor->Squeeze();
|
||||||
|
|
||||||
|
*out = std::move(fea_tensor);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataImpl::Init() {
|
||||||
RETURN_IF_NOT_OK(LoadNodeAndEdge());
|
RETURN_IF_NOT_OK(LoadNodeAndEdge());
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Graph::GetMetaInfo(MetaInfo *meta_info) {
|
Status GraphDataImpl::GetMetaInfo(MetaInfo *meta_info) {
|
||||||
meta_info->node_type.resize(node_type_map_.size());
|
meta_info->node_type.resize(node_type_map_.size());
|
||||||
std::transform(node_type_map_.begin(), node_type_map_.end(), meta_info->node_type.begin(),
|
std::transform(node_type_map_.begin(), node_type_map_.end(), meta_info->node_type.begin(),
|
||||||
[](auto itr) { return itr.first; });
|
[](auto itr) { return itr.first; });
|
||||||
|
@ -427,7 +508,7 @@ Status Graph::GetMetaInfo(MetaInfo *meta_info) {
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef ENABLE_PYTHON
|
#ifdef ENABLE_PYTHON
|
||||||
Status Graph::GraphInfo(py::dict *out) {
|
Status GraphDataImpl::GraphInfo(py::dict *out) {
|
||||||
MetaInfo meta_info;
|
MetaInfo meta_info;
|
||||||
RETURN_IF_NOT_OK(GetMetaInfo(&meta_info));
|
RETURN_IF_NOT_OK(GetMetaInfo(&meta_info));
|
||||||
(*out)["node_type"] = py::cast(meta_info.node_type);
|
(*out)["node_type"] = py::cast(meta_info.node_type);
|
||||||
|
@ -440,18 +521,16 @@ Status Graph::GraphInfo(py::dict *out) {
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
Status Graph::LoadNodeAndEdge() {
|
Status GraphDataImpl::LoadNodeAndEdge() {
|
||||||
GraphLoader gl(dataset_file_, num_workers_);
|
GraphLoader gl(this, dataset_file_, num_workers_, server_mode_);
|
||||||
// ask graph_loader to load everything into memory
|
// ask graph_loader to load everything into memory
|
||||||
RETURN_IF_NOT_OK(gl.InitAndLoad());
|
RETURN_IF_NOT_OK(gl.InitAndLoad());
|
||||||
// get all maps
|
// get all maps
|
||||||
RETURN_IF_NOT_OK(gl.GetNodesAndEdges(&node_id_map_, &edge_id_map_, &node_type_map_, &edge_type_map_,
|
RETURN_IF_NOT_OK(gl.GetNodesAndEdges());
|
||||||
&node_feature_map_, &edge_feature_map_, &default_node_feature_map_,
|
|
||||||
&default_edge_feature_map_));
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Graph::GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node) {
|
Status GraphDataImpl::GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node) {
|
||||||
auto itr = node_id_map_.find(id);
|
auto itr = node_id_map_.find(id);
|
||||||
if (itr == node_id_map_.end()) {
|
if (itr == node_id_map_.end()) {
|
||||||
std::string err_msg = "Invalid node id:" + std::to_string(id);
|
std::string err_msg = "Invalid node id:" + std::to_string(id);
|
||||||
|
@ -462,7 +541,7 @@ Status Graph::GetNodeByNodeId(NodeIdType id, std::shared_ptr<Node> *node) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Graph::GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr<Edge> *edge) {
|
Status GraphDataImpl::GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr<Edge> *edge) {
|
||||||
auto itr = edge_id_map_.find(id);
|
auto itr = edge_id_map_.find(id);
|
||||||
if (itr == edge_id_map_.end()) {
|
if (itr == edge_id_map_.end()) {
|
||||||
std::string err_msg = "Invalid edge id:" + std::to_string(id);
|
std::string err_msg = "Invalid edge id:" + std::to_string(id);
|
||||||
|
@ -473,12 +552,13 @@ Status Graph::GetEdgeByEdgeId(EdgeIdType id, std::shared_ptr<Edge> *edge) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Graph::RandomWalkBase::RandomWalkBase(Graph *graph)
|
GraphDataImpl::RandomWalkBase::RandomWalkBase(GraphDataImpl *graph)
|
||||||
: graph_(graph), step_home_param_(1.0), step_away_param_(1.0), default_node_(-1), num_walks_(1), num_workers_(1) {}
|
: graph_(graph), step_home_param_(1.0), step_away_param_(1.0), default_node_(-1), num_walks_(1), num_workers_(1) {}
|
||||||
|
|
||||||
Status Graph::RandomWalkBase::Build(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
|
Status GraphDataImpl::RandomWalkBase::Build(const std::vector<NodeIdType> &node_list,
|
||||||
float step_home_param, float step_away_param, const NodeIdType default_node,
|
const std::vector<NodeType> &meta_path, float step_home_param,
|
||||||
int32_t num_walks, int32_t num_workers) {
|
float step_away_param, const NodeIdType default_node, int32_t num_walks,
|
||||||
|
int32_t num_workers) {
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty.");
|
CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty.");
|
||||||
node_list_ = node_list;
|
node_list_ = node_list;
|
||||||
if (meta_path.empty() || meta_path.size() > kMaxNumWalks) {
|
if (meta_path.empty() || meta_path.size() > kMaxNumWalks) {
|
||||||
|
@ -516,7 +596,7 @@ Status Graph::RandomWalkBase::Build(const std::vector<NodeIdType> &node_list, co
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Graph::RandomWalkBase::Node2vecWalk(const NodeIdType &start_node, std::vector<NodeIdType> *walk_path) {
|
Status GraphDataImpl::RandomWalkBase::Node2vecWalk(const NodeIdType &start_node, std::vector<NodeIdType> *walk_path) {
|
||||||
// Simulate a random walk starting from start node.
|
// Simulate a random walk starting from start node.
|
||||||
auto walk = std::vector<NodeIdType>(1, start_node); // walk is an vector
|
auto walk = std::vector<NodeIdType>(1, start_node); // walk is an vector
|
||||||
// walk simulate
|
// walk simulate
|
||||||
|
@ -556,8 +636,8 @@ Status Graph::RandomWalkBase::Node2vecWalk(const NodeIdType &start_node, std::ve
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Graph::RandomWalkBase::SimulateWalk(std::vector<std::vector<NodeIdType>> *walks) {
|
Status GraphDataImpl::RandomWalkBase::SimulateWalk(std::vector<std::vector<NodeIdType>> *walks) {
|
||||||
for (int32_t i = 0; i < num_walks_; i++) {
|
for (int32_t i = 0; i < num_walks_; ++i) {
|
||||||
for (const auto &node : node_list_) {
|
for (const auto &node : node_list_) {
|
||||||
std::vector<NodeIdType> walk;
|
std::vector<NodeIdType> walk;
|
||||||
RETURN_IF_NOT_OK(Node2vecWalk(node, &walk));
|
RETURN_IF_NOT_OK(Node2vecWalk(node, &walk));
|
||||||
|
@ -567,8 +647,8 @@ Status Graph::RandomWalkBase::SimulateWalk(std::vector<std::vector<NodeIdType>>
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Graph::RandomWalkBase::GetNodeProbability(const NodeIdType &node_id, const NodeType &node_type,
|
Status GraphDataImpl::RandomWalkBase::GetNodeProbability(const NodeIdType &node_id, const NodeType &node_type,
|
||||||
std::shared_ptr<StochasticIndex> *node_probability) {
|
std::shared_ptr<StochasticIndex> *node_probability) {
|
||||||
// Generate alias nodes
|
// Generate alias nodes
|
||||||
std::shared_ptr<Node> node;
|
std::shared_ptr<Node> node;
|
||||||
graph_->GetNodeByNodeId(node_id, &node);
|
graph_->GetNodeByNodeId(node_id, &node);
|
||||||
|
@ -581,8 +661,9 @@ Status Graph::RandomWalkBase::GetNodeProbability(const NodeIdType &node_id, cons
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Graph::RandomWalkBase::GetEdgeProbability(const NodeIdType &src, const NodeIdType &dst, uint32_t meta_path_index,
|
Status GraphDataImpl::RandomWalkBase::GetEdgeProbability(const NodeIdType &src, const NodeIdType &dst,
|
||||||
std::shared_ptr<StochasticIndex> *edge_probability) {
|
uint32_t meta_path_index,
|
||||||
|
std::shared_ptr<StochasticIndex> *edge_probability) {
|
||||||
// Get the alias edge setup lists for a given edge.
|
// Get the alias edge setup lists for a given edge.
|
||||||
std::shared_ptr<Node> src_node;
|
std::shared_ptr<Node> src_node;
|
||||||
graph_->GetNodeByNodeId(src, &src_node);
|
graph_->GetNodeByNodeId(src, &src_node);
|
||||||
|
@ -616,7 +697,7 @@ Status Graph::RandomWalkBase::GetEdgeProbability(const NodeIdType &src, const No
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
StochasticIndex Graph::RandomWalkBase::GenerateProbability(const std::vector<float> &probability) {
|
StochasticIndex GraphDataImpl::RandomWalkBase::GenerateProbability(const std::vector<float> &probability) {
|
||||||
uint32_t K = probability.size();
|
uint32_t K = probability.size();
|
||||||
std::vector<int32_t> switch_to_large_index(K, 0);
|
std::vector<int32_t> switch_to_large_index(K, 0);
|
||||||
std::vector<float> weight(K, .0);
|
std::vector<float> weight(K, .0);
|
||||||
|
@ -644,7 +725,7 @@ StochasticIndex Graph::RandomWalkBase::GenerateProbability(const std::vector<flo
|
||||||
return StochasticIndex(switch_to_large_index, weight);
|
return StochasticIndex(switch_to_large_index, weight);
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t Graph::RandomWalkBase::WalkToNextNode(const StochasticIndex &stochastic_index) {
|
uint32_t GraphDataImpl::RandomWalkBase::WalkToNextNode(const StochasticIndex &stochastic_index) {
|
||||||
auto switch_to_large_index = stochastic_index.first;
|
auto switch_to_large_index = stochastic_index.first;
|
||||||
auto weight = stochastic_index.second;
|
auto weight = stochastic_index.second;
|
||||||
const uint32_t size_of_index = switch_to_large_index.size();
|
const uint32_t size_of_index = switch_to_large_index.size();
|
||||||
|
@ -662,7 +743,7 @@ uint32_t Graph::RandomWalkBase::WalkToNextNode(const StochasticIndex &stochastic
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
std::vector<float> Graph::RandomWalkBase::Normalize(const std::vector<T> &non_normalized_probability) {
|
std::vector<float> GraphDataImpl::RandomWalkBase::Normalize(const std::vector<T> &non_normalized_probability) {
|
||||||
float sum_probability =
|
float sum_probability =
|
||||||
1.0 * std::accumulate(non_normalized_probability.begin(), non_normalized_probability.end(), 0);
|
1.0 * std::accumulate(non_normalized_probability.begin(), non_normalized_probability.end(), 0);
|
||||||
if (sum_probability < kGnnEpsilon) {
|
if (sum_probability < kGnnEpsilon) {
|
|
@ -13,8 +13,8 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_H_
|
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_IMPL_H_
|
||||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_H_
|
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_IMPL_H_
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
@ -25,13 +25,11 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "minddata/dataset/core/tensor.h"
|
#include "minddata/dataset/engine/gnn/graph_data.h"
|
||||||
#include "minddata/dataset/core/tensor_row.h"
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
#include "minddata/dataset/engine/gnn/graph_loader.h"
|
#include "minddata/dataset/engine/gnn/graph_shared_memory.h"
|
||||||
#include "minddata/dataset/engine/gnn/feature.h"
|
#endif
|
||||||
#include "minddata/dataset/engine/gnn/node.h"
|
#include "minddata/mindrecord/include/common/shard_utils.h"
|
||||||
#include "minddata/dataset/engine/gnn/edge.h"
|
|
||||||
#include "minddata/dataset/util/status.h"
|
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
@ -41,41 +39,32 @@ const float kGnnEpsilon = 0.0001;
|
||||||
const uint32_t kMaxNumWalks = 80;
|
const uint32_t kMaxNumWalks = 80;
|
||||||
using StochasticIndex = std::pair<std::vector<int32_t>, std::vector<float>>;
|
using StochasticIndex = std::pair<std::vector<int32_t>, std::vector<float>>;
|
||||||
|
|
||||||
struct MetaInfo {
|
class GraphDataImpl : public GraphData {
|
||||||
std::vector<NodeType> node_type;
|
|
||||||
std::vector<EdgeType> edge_type;
|
|
||||||
std::map<NodeType, NodeIdType> node_num;
|
|
||||||
std::map<EdgeType, EdgeIdType> edge_num;
|
|
||||||
std::vector<FeatureType> node_feature_type;
|
|
||||||
std::vector<FeatureType> edge_feature_type;
|
|
||||||
};
|
|
||||||
|
|
||||||
class Graph {
|
|
||||||
public:
|
public:
|
||||||
// Constructor
|
// Constructor
|
||||||
// @param std::string dataset_file -
|
// @param std::string dataset_file -
|
||||||
// @param int32_t num_workers - number of parallel threads
|
// @param int32_t num_workers - number of parallel threads
|
||||||
Graph(std::string dataset_file, int32_t num_workers);
|
GraphDataImpl(std::string dataset_file, int32_t num_workers, bool server_mode = false);
|
||||||
|
|
||||||
~Graph() = default;
|
~GraphDataImpl();
|
||||||
|
|
||||||
// Get all nodes from the graph.
|
// Get all nodes from the graph.
|
||||||
// @param NodeType node_type - type of node
|
// @param NodeType node_type - type of node
|
||||||
// @param std::shared_ptr<Tensor> *out - Returned nodes id
|
// @param std::shared_ptr<Tensor> *out - Returned nodes id
|
||||||
// @return Status - The error code return
|
// @return Status - The error code return
|
||||||
Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out);
|
Status GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) override;
|
||||||
|
|
||||||
// Get all edges from the graph.
|
// Get all edges from the graph.
|
||||||
// @param NodeType edge_type - type of edge
|
// @param NodeType edge_type - type of edge
|
||||||
// @param std::shared_ptr<Tensor> *out - Returned edge ids
|
// @param std::shared_ptr<Tensor> *out - Returned edge ids
|
||||||
// @return Status - The error code return
|
// @return Status - The error code return
|
||||||
Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out);
|
Status GetAllEdges(EdgeType edge_type, std::shared_ptr<Tensor> *out) override;
|
||||||
|
|
||||||
// Get the node id from the edge.
|
// Get the node id from the edge.
|
||||||
// @param std::vector<EdgeIdType> edge_list - List of edges
|
// @param std::vector<EdgeIdType> edge_list - List of edges
|
||||||
// @param std::shared_ptr<Tensor> *out - Returned node ids
|
// @param std::shared_ptr<Tensor> *out - Returned node ids
|
||||||
// @return Status - The error code return
|
// @return Status - The error code return
|
||||||
Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out);
|
Status GetNodesFromEdges(const std::vector<EdgeIdType> &edge_list, std::shared_ptr<Tensor> *out) override;
|
||||||
|
|
||||||
// All neighbors of the acquisition node.
|
// All neighbors of the acquisition node.
|
||||||
// @param std::vector<NodeType> node_list - List of nodes
|
// @param std::vector<NodeType> node_list - List of nodes
|
||||||
|
@ -85,7 +74,7 @@ class Graph {
|
||||||
// is not enough, fill in tensor as -1.
|
// is not enough, fill in tensor as -1.
|
||||||
// @return Status - The error code return
|
// @return Status - The error code return
|
||||||
Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
|
Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
|
||||||
std::shared_ptr<Tensor> *out);
|
std::shared_ptr<Tensor> *out) override;
|
||||||
|
|
||||||
// Get sampled neighbors.
|
// Get sampled neighbors.
|
||||||
// @param std::vector<NodeType> node_list - List of nodes
|
// @param std::vector<NodeType> node_list - List of nodes
|
||||||
|
@ -94,7 +83,7 @@ class Graph {
|
||||||
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
|
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id.
|
||||||
// @return Status - The error code return
|
// @return Status - The error code return
|
||||||
Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
|
Status GetSampledNeighbors(const std::vector<NodeIdType> &node_list, const std::vector<NodeIdType> &neighbor_nums,
|
||||||
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out);
|
const std::vector<NodeType> &neighbor_types, std::shared_ptr<Tensor> *out) override;
|
||||||
|
|
||||||
// Get negative sampled neighbors.
|
// Get negative sampled neighbors.
|
||||||
// @param std::vector<NodeType> node_list - List of nodes
|
// @param std::vector<NodeType> node_list - List of nodes
|
||||||
|
@ -103,7 +92,7 @@ class Graph {
|
||||||
// @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id.
|
// @param std::shared_ptr<Tensor> *out - Returned negative neighbor's id.
|
||||||
// @return Status - The error code return
|
// @return Status - The error code return
|
||||||
Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
|
Status GetNegSampledNeighbors(const std::vector<NodeIdType> &node_list, NodeIdType samples_num,
|
||||||
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out);
|
NodeType neg_neighbor_type, std::shared_ptr<Tensor> *out) override;
|
||||||
|
|
||||||
// Node2vec random walk.
|
// Node2vec random walk.
|
||||||
// @param std::vector<NodeIdType> node_list - List of nodes
|
// @param std::vector<NodeIdType> node_list - List of nodes
|
||||||
|
@ -115,7 +104,7 @@ class Graph {
|
||||||
// @return Status - The error code return
|
// @return Status - The error code return
|
||||||
Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
|
Status RandomWalk(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
|
||||||
float step_home_param, float step_away_param, NodeIdType default_node,
|
float step_home_param, float step_away_param, NodeIdType default_node,
|
||||||
std::shared_ptr<Tensor> *out);
|
std::shared_ptr<Tensor> *out) override;
|
||||||
|
|
||||||
// Get the feature of a node
|
// Get the feature of a node
|
||||||
// @param std::shared_ptr<Tensor> nodes - List of nodes
|
// @param std::shared_ptr<Tensor> nodes - List of nodes
|
||||||
|
@ -124,16 +113,22 @@ class Graph {
|
||||||
// @param TensorRow *out - Returned features
|
// @param TensorRow *out - Returned features
|
||||||
// @return Status - The error code return
|
// @return Status - The error code return
|
||||||
Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
|
Status GetNodeFeature(const std::shared_ptr<Tensor> &nodes, const std::vector<FeatureType> &feature_types,
|
||||||
TensorRow *out);
|
TensorRow *out) override;
|
||||||
|
|
||||||
|
Status GetNodeFeatureSharedMemory(const std::shared_ptr<Tensor> &nodes, FeatureType type,
|
||||||
|
std::shared_ptr<Tensor> *out);
|
||||||
|
|
||||||
// Get the feature of a edge
|
// Get the feature of a edge
|
||||||
// @param std::shared_ptr<Tensor> edget - List of edges
|
// @param std::shared_ptr<Tensor> edges - List of edges
|
||||||
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
|
// @param std::vector<FeatureType> feature_types - Types of features, An error will be reported if the feature type
|
||||||
// does not exist.
|
// does not exist.
|
||||||
// @param Tensor *out - Returned features
|
// @param Tensor *out - Returned features
|
||||||
// @return Status - The error code return
|
// @return Status - The error code return
|
||||||
Status GetEdgeFeature(const std::shared_ptr<Tensor> &edget, const std::vector<FeatureType> &feature_types,
|
Status GetEdgeFeature(const std::shared_ptr<Tensor> &edges, const std::vector<FeatureType> &feature_types,
|
||||||
TensorRow *out);
|
TensorRow *out) override;
|
||||||
|
|
||||||
|
Status GetEdgeFeatureSharedMemory(const std::shared_ptr<Tensor> &edges, FeatureType type,
|
||||||
|
std::shared_ptr<Tensor> *out);
|
||||||
|
|
||||||
// Get meta information of graph
|
// Get meta information of graph
|
||||||
// @param MetaInfo *meta_info - Returned meta information
|
// @param MetaInfo *meta_info - Returned meta information
|
||||||
|
@ -142,15 +137,34 @@ class Graph {
|
||||||
|
|
||||||
#ifdef ENABLE_PYTHON
|
#ifdef ENABLE_PYTHON
|
||||||
// Return meta information to python layer
|
// Return meta information to python layer
|
||||||
Status GraphInfo(py::dict *out);
|
Status GraphInfo(py::dict *out) override;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
Status Init();
|
const std::unordered_map<FeatureType, std::shared_ptr<Feature>> *GetAllDefaultNodeFeatures() {
|
||||||
|
return &default_node_feature_map_;
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::unordered_map<FeatureType, std::shared_ptr<Feature>> *GetAllDefaultEdgeFeatures() {
|
||||||
|
return &default_edge_feature_map_;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status Init() override;
|
||||||
|
|
||||||
|
Status Stop() override { return Status::OK(); }
|
||||||
|
|
||||||
|
std::string GetDataSchema() { return data_schema_.dump(); }
|
||||||
|
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
key_t GetSharedMemoryKey() { return graph_shared_memory_->memory_key(); }
|
||||||
|
|
||||||
|
int64_t GetSharedMemorySize() { return graph_shared_memory_->memory_size(); }
|
||||||
|
#endif
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
friend class GraphLoader;
|
||||||
class RandomWalkBase {
|
class RandomWalkBase {
|
||||||
public:
|
public:
|
||||||
explicit RandomWalkBase(Graph *graph);
|
explicit RandomWalkBase(GraphDataImpl *graph);
|
||||||
|
|
||||||
Status Build(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
|
Status Build(const std::vector<NodeIdType> &node_list, const std::vector<NodeType> &meta_path,
|
||||||
float step_home_param = 1.0, float step_away_param = 1.0, NodeIdType default_node = -1,
|
float step_home_param = 1.0, float step_away_param = 1.0, NodeIdType default_node = -1,
|
||||||
|
@ -176,7 +190,7 @@ class Graph {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
std::vector<float> Normalize(const std::vector<T> &non_normalized_probability);
|
std::vector<float> Normalize(const std::vector<T> &non_normalized_probability);
|
||||||
|
|
||||||
Graph *graph_;
|
GraphDataImpl *graph_;
|
||||||
std::vector<NodeIdType> node_list_;
|
std::vector<NodeIdType> node_list_;
|
||||||
std::vector<NodeType> meta_path_;
|
std::vector<NodeType> meta_path_;
|
||||||
float step_home_param_; // Return hyper parameter. Default is 1.0
|
float step_home_param_; // Return hyper parameter. Default is 1.0
|
||||||
|
@ -248,7 +262,11 @@ class Graph {
|
||||||
int32_t num_workers_; // The number of worker threads
|
int32_t num_workers_; // The number of worker threads
|
||||||
std::mt19937 rnd_;
|
std::mt19937 rnd_;
|
||||||
RandomWalkBase random_walk_;
|
RandomWalkBase random_walk_;
|
||||||
|
mindrecord::json data_schema_;
|
||||||
|
bool server_mode_;
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
std::unique_ptr<GraphSharedMemory> graph_shared_memory_;
|
||||||
|
#endif
|
||||||
std::unordered_map<NodeType, std::vector<NodeIdType>> node_type_map_;
|
std::unordered_map<NodeType, std::vector<NodeIdType>> node_type_map_;
|
||||||
std::unordered_map<NodeIdType, std::shared_ptr<Node>> node_id_map_;
|
std::unordered_map<NodeIdType, std::shared_ptr<Node>> node_id_map_;
|
||||||
|
|
||||||
|
@ -264,4 +282,4 @@ class Graph {
|
||||||
} // namespace gnn
|
} // namespace gnn
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_H_
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_IMPL_H_
|
|
@ -0,0 +1,133 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "minddata/dataset/engine/gnn/graph_data_server.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <functional>
|
||||||
|
#include <iterator>
|
||||||
|
#include <numeric>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "minddata/dataset/core/tensor_shape.h"
|
||||||
|
#include "minddata/dataset/engine/gnn/graph_data_impl.h"
|
||||||
|
#include "minddata/dataset/util/random.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
namespace gnn {
|
||||||
|
|
||||||
|
GraphDataServer::GraphDataServer(const std::string &dataset_file, int32_t num_workers, const std::string &hostname,
|
||||||
|
int32_t port, int32_t client_num, bool auto_shutdown)
|
||||||
|
: dataset_file_(dataset_file),
|
||||||
|
num_workers_(num_workers),
|
||||||
|
client_num_(client_num),
|
||||||
|
max_connected_client_num_(0),
|
||||||
|
auto_shutdown_(auto_shutdown),
|
||||||
|
state_(kGdsUninit) {
|
||||||
|
tg_ = std::make_unique<TaskGroup>();
|
||||||
|
graph_data_impl_ = std::make_unique<GraphDataImpl>(dataset_file, num_workers, true);
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
service_impl_ = std::make_unique<GraphDataServiceImpl>(this, graph_data_impl_.get());
|
||||||
|
async_server_ = std::make_unique<GraphDataGrpcServer>(hostname, port, service_impl_.get());
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataServer::Init() {
|
||||||
|
#if defined(_WIN32) || defined(_WIN64)
|
||||||
|
RETURN_STATUS_UNEXPECTED("Graph data server is not supported in Windows OS");
|
||||||
|
#else
|
||||||
|
set_state(kGdsInitializing);
|
||||||
|
RETURN_IF_NOT_OK(async_server_->Run());
|
||||||
|
// RETURN_IF_NOT_OK(InitGraphDataImpl());
|
||||||
|
RETURN_IF_NOT_OK(tg_->CreateAsyncTask("init graph data impl", std::bind(&GraphDataServer::InitGraphDataImpl, this)));
|
||||||
|
for (int32_t i = 0; i < num_workers_; ++i) {
|
||||||
|
RETURN_IF_NOT_OK(
|
||||||
|
tg_->CreateAsyncTask("start async rpc service", std::bind(&GraphDataServer::StartAsyncRpcService, this)));
|
||||||
|
}
|
||||||
|
if (auto_shutdown_) {
|
||||||
|
RETURN_IF_NOT_OK(
|
||||||
|
tg_->CreateAsyncTask("judge auto shutdown server", std::bind(&GraphDataServer::JudgeAutoShutdownServer, this)));
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataServer::InitGraphDataImpl() {
|
||||||
|
TaskManager::FindMe()->Post();
|
||||||
|
Status s = graph_data_impl_->Init();
|
||||||
|
if (s.IsOk()) {
|
||||||
|
set_state(kGdsRunning);
|
||||||
|
} else {
|
||||||
|
(void)Stop();
|
||||||
|
}
|
||||||
|
return s;
|
||||||
|
}
|
||||||
|
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
Status GraphDataServer::StartAsyncRpcService() {
|
||||||
|
TaskManager::FindMe()->Post();
|
||||||
|
RETURN_IF_NOT_OK(async_server_->HandleRequest());
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
Status GraphDataServer::JudgeAutoShutdownServer() {
|
||||||
|
TaskManager::FindMe()->Post();
|
||||||
|
while (true) {
|
||||||
|
if (auto_shutdown_ && (max_connected_client_num_ >= client_num_) && (client_pid_.size() == 0)) {
|
||||||
|
MS_LOG(INFO) << "All clients have been unregister, automatically exit the server.";
|
||||||
|
RETURN_IF_NOT_OK(Stop());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
if (state_ == kGdsStopped) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataServer::Stop() {
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
async_server_->Stop();
|
||||||
|
#endif
|
||||||
|
set_state(kGdsStopped);
|
||||||
|
graph_data_impl_.reset();
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataServer::ClientRegister(int32_t pid) {
|
||||||
|
std::unique_lock<std::mutex> lck(mutex_);
|
||||||
|
MS_LOG(INFO) << "client register pid:" << std::to_string(pid);
|
||||||
|
client_pid_.emplace(pid);
|
||||||
|
if (client_pid_.size() > max_connected_client_num_) {
|
||||||
|
max_connected_client_num_ = client_pid_.size();
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
Status GraphDataServer::ClientUnRegister(int32_t pid) {
|
||||||
|
std::unique_lock<std::mutex> lck(mutex_);
|
||||||
|
auto itr = client_pid_.find(pid);
|
||||||
|
if (itr != client_pid_.end()) {
|
||||||
|
client_pid_.erase(itr);
|
||||||
|
MS_LOG(INFO) << "client unregister pid:" << std::to_string(pid);
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gnn
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,196 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVER_H_
|
||||||
|
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVER_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <mutex>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_set>
|
||||||
|
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
#include "grpcpp/grpcpp.h"
|
||||||
|
#include "minddata/dataset/engine/gnn/graph_data_service_impl.h"
|
||||||
|
#include "minddata/dataset/engine/gnn/grpc_async_server.h"
|
||||||
|
#endif
|
||||||
|
#include "minddata/dataset/util/task_manager.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
namespace gnn {
|
||||||
|
|
||||||
|
class GraphDataImpl;
|
||||||
|
|
||||||
|
class GraphDataServer {
|
||||||
|
public:
|
||||||
|
enum ServerState { kGdsUninit = 0, kGdsInitializing, kGdsRunning, kGdsStopped };
|
||||||
|
GraphDataServer(const std::string &dataset_file, int32_t num_workers, const std::string &hostname, int32_t port,
|
||||||
|
int32_t client_num, bool auto_shutdown);
|
||||||
|
~GraphDataServer() = default;
|
||||||
|
|
||||||
|
Status Init();
|
||||||
|
|
||||||
|
Status Stop();
|
||||||
|
|
||||||
|
Status ClientRegister(int32_t pid);
|
||||||
|
Status ClientUnRegister(int32_t pid);
|
||||||
|
|
||||||
|
enum ServerState state() { return state_; }
|
||||||
|
|
||||||
|
bool IsStoped() {
|
||||||
|
if (state_ == kGdsStopped) {
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void set_state(enum ServerState state) { state_ = state; }
|
||||||
|
|
||||||
|
Status InitGraphDataImpl();
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
Status StartAsyncRpcService();
|
||||||
|
#endif
|
||||||
|
Status JudgeAutoShutdownServer();
|
||||||
|
|
||||||
|
std::string dataset_file_;
|
||||||
|
int32_t num_workers_; // The number of worker threads
|
||||||
|
int32_t client_num_;
|
||||||
|
int32_t max_connected_client_num_;
|
||||||
|
bool auto_shutdown_;
|
||||||
|
enum ServerState state_;
|
||||||
|
std::unique_ptr<TaskGroup> tg_; // Class for worker management
|
||||||
|
std::unique_ptr<GraphDataImpl> graph_data_impl_;
|
||||||
|
std::unordered_set<int32_t> client_pid_;
|
||||||
|
std::mutex mutex_;
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
std::unique_ptr<GraphDataServiceImpl> service_impl_;
|
||||||
|
std::unique_ptr<GrpcAsyncServer> async_server_;
|
||||||
|
#endif
|
||||||
|
};
|
||||||
|
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
class UntypedCall {
|
||||||
|
public:
|
||||||
|
virtual ~UntypedCall() {}
|
||||||
|
|
||||||
|
virtual Status operator()() = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
template <class ServiceImpl, class AsyncService, class RequestMessage, class ResponseMessage>
|
||||||
|
class CallData : public UntypedCall {
|
||||||
|
public:
|
||||||
|
enum class STATE : int8_t { CREATE = 1, PROCESS = 2, FINISH = 3 };
|
||||||
|
using EnqueueFunction = void (AsyncService::*)(grpc::ServerContext *, RequestMessage *,
|
||||||
|
grpc::ServerAsyncResponseWriter<ResponseMessage> *,
|
||||||
|
grpc::CompletionQueue *, grpc::ServerCompletionQueue *, void *);
|
||||||
|
using HandleRequestFunction = grpc::Status (ServiceImpl::*)(grpc::ServerContext *, const RequestMessage *,
|
||||||
|
ResponseMessage *);
|
||||||
|
CallData(ServiceImpl *service_impl, AsyncService *async_service, grpc::ServerCompletionQueue *cq,
|
||||||
|
EnqueueFunction enqueue_function, HandleRequestFunction handle_request_function)
|
||||||
|
: status_(STATE::CREATE),
|
||||||
|
service_impl_(service_impl),
|
||||||
|
async_service_(async_service),
|
||||||
|
cq_(cq),
|
||||||
|
enqueue_function_(enqueue_function),
|
||||||
|
handle_request_function_(handle_request_function),
|
||||||
|
responder_(&ctx_) {}
|
||||||
|
|
||||||
|
~CallData() = default;
|
||||||
|
|
||||||
|
static Status EnqueueRequest(ServiceImpl *service_impl, AsyncService *async_service, grpc::ServerCompletionQueue *cq,
|
||||||
|
EnqueueFunction enqueue_function, HandleRequestFunction handle_request_function) {
|
||||||
|
auto call = new CallData<ServiceImpl, AsyncService, RequestMessage, ResponseMessage>(
|
||||||
|
service_impl, async_service, cq, enqueue_function, handle_request_function);
|
||||||
|
RETURN_IF_NOT_OK((*call)());
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status operator()() {
|
||||||
|
if (status_ == STATE::CREATE) {
|
||||||
|
status_ = STATE::PROCESS;
|
||||||
|
(async_service_->*enqueue_function_)(&ctx_, &request_, &responder_, cq_, cq_, this);
|
||||||
|
} else if (status_ == STATE::PROCESS) {
|
||||||
|
EnqueueRequest(service_impl_, async_service_, cq_, enqueue_function_, handle_request_function_);
|
||||||
|
status_ = STATE::FINISH;
|
||||||
|
// new CallData(service_, cq_, this->s_type_);
|
||||||
|
grpc::Status s = (service_impl_->*handle_request_function_)(&ctx_, &request_, &response_);
|
||||||
|
responder_.Finish(response_, s, this);
|
||||||
|
} else {
|
||||||
|
GPR_ASSERT(status_ == STATE::FINISH);
|
||||||
|
delete this;
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
STATE status_;
|
||||||
|
ServiceImpl *service_impl_;
|
||||||
|
AsyncService *async_service_;
|
||||||
|
grpc::ServerCompletionQueue *cq_;
|
||||||
|
EnqueueFunction enqueue_function_;
|
||||||
|
HandleRequestFunction handle_request_function_;
|
||||||
|
grpc::ServerContext ctx_;
|
||||||
|
grpc::ServerAsyncResponseWriter<ResponseMessage> responder_;
|
||||||
|
RequestMessage request_;
|
||||||
|
ResponseMessage response_;
|
||||||
|
};
|
||||||
|
|
||||||
|
#define ENQUEUE_REQUEST(service_impl, async_service, cq, method, request_msg, response_msg) \
|
||||||
|
do { \
|
||||||
|
Status s = \
|
||||||
|
CallData<gnn::GraphDataServiceImpl, GnnGraphData::AsyncService, request_msg, response_msg>::EnqueueRequest( \
|
||||||
|
service_impl, async_service, cq, &GnnGraphData::AsyncService::Request##method, \
|
||||||
|
&gnn::GraphDataServiceImpl::method); \
|
||||||
|
RETURN_IF_NOT_OK(s); \
|
||||||
|
} while (0)
|
||||||
|
|
||||||
|
class GraphDataGrpcServer : public GrpcAsyncServer {
|
||||||
|
public:
|
||||||
|
GraphDataGrpcServer(const std::string &host, int32_t port, GraphDataServiceImpl *service_impl)
|
||||||
|
: GrpcAsyncServer(host, port), service_impl_(service_impl) {}
|
||||||
|
|
||||||
|
Status RegisterService(grpc::ServerBuilder *builder) {
|
||||||
|
builder->RegisterService(&svc_);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status EnqueueRequest() {
|
||||||
|
ENQUEUE_REQUEST(service_impl_, &svc_, cq_.get(), ClientRegister, GnnClientRegisterRequestPb,
|
||||||
|
GnnClientRegisterResponsePb);
|
||||||
|
ENQUEUE_REQUEST(service_impl_, &svc_, cq_.get(), ClientUnRegister, GnnClientUnRegisterRequestPb,
|
||||||
|
GnnClientUnRegisterResponsePb);
|
||||||
|
ENQUEUE_REQUEST(service_impl_, &svc_, cq_.get(), GetGraphData, GnnGraphDataRequestPb, GnnGraphDataResponsePb);
|
||||||
|
ENQUEUE_REQUEST(service_impl_, &svc_, cq_.get(), GetMetaInfo, GnnMetaInfoRequestPb, GnnMetaInfoResponsePb);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status ProcessRequest(void *tag) {
|
||||||
|
auto rq = static_cast<UntypedCall *>(tag);
|
||||||
|
RETURN_IF_NOT_OK((*rq)());
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
GraphDataServiceImpl *service_impl_;
|
||||||
|
GnnGraphData::AsyncService svc_;
|
||||||
|
};
|
||||||
|
#endif
|
||||||
|
} // namespace gnn
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVER_H_
|
|
@ -0,0 +1,299 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "minddata/dataset/engine/gnn/graph_data_service_impl.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "minddata/dataset/engine/gnn/tensor_proto.h"
|
||||||
|
#include "minddata/dataset/engine/gnn/graph_data_server.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
namespace gnn {
|
||||||
|
|
||||||
|
using pFunction = Status (GraphDataServiceImpl::*)(const GnnGraphDataRequestPb *, GnnGraphDataResponsePb *);
|
||||||
|
static std::unordered_map<uint32_t, pFunction> g_get_graph_data_func_ = {
|
||||||
|
{GET_ALL_NODES, &GraphDataServiceImpl::GetAllNodes},
|
||||||
|
{GET_ALL_EDGES, &GraphDataServiceImpl::GetAllEdges},
|
||||||
|
{GET_NODES_FROM_EDGES, &GraphDataServiceImpl::GetNodesFromEdges},
|
||||||
|
{GET_ALL_NEIGHBORS, &GraphDataServiceImpl::GetAllNeighbors},
|
||||||
|
{GET_SAMPLED_NEIGHBORS, &GraphDataServiceImpl::GetSampledNeighbors},
|
||||||
|
{GET_NEG_SAMPLED_NEIGHBORS, &GraphDataServiceImpl::GetNegSampledNeighbors},
|
||||||
|
{RANDOM_WALK, &GraphDataServiceImpl::RandomWalk},
|
||||||
|
{GET_NODE_FEATURE, &GraphDataServiceImpl::GetNodeFeature},
|
||||||
|
{GET_EDGE_FEATURE, &GraphDataServiceImpl::GetEdgeFeature}};
|
||||||
|
|
||||||
|
GraphDataServiceImpl::GraphDataServiceImpl(GraphDataServer *server, GraphDataImpl *graph_data_impl)
|
||||||
|
: server_(server), graph_data_impl_(graph_data_impl) {}
|
||||||
|
|
||||||
|
Status GraphDataServiceImpl::FillDefaultFeature(GnnClientRegisterResponsePb *response) {
|
||||||
|
const auto default_node_features = graph_data_impl_->GetAllDefaultNodeFeatures();
|
||||||
|
for (const auto feature : *default_node_features) {
|
||||||
|
GnnFeatureInfoPb *feature_info = response->add_default_node_feature();
|
||||||
|
feature_info->set_type(feature.first);
|
||||||
|
RETURN_IF_NOT_OK(TensorToPb(feature.second->Value(), feature_info->mutable_feature()));
|
||||||
|
}
|
||||||
|
const auto default_edge_features = graph_data_impl_->GetAllDefaultEdgeFeatures();
|
||||||
|
for (const auto feature : *default_edge_features) {
|
||||||
|
GnnFeatureInfoPb *feature_info = response->add_default_edge_feature();
|
||||||
|
feature_info->set_type(feature.first);
|
||||||
|
RETURN_IF_NOT_OK(TensorToPb(feature.second->Value(), feature_info->mutable_feature()));
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
grpc::Status GraphDataServiceImpl::ClientRegister(grpc::ServerContext *context,
|
||||||
|
const GnnClientRegisterRequestPb *request,
|
||||||
|
GnnClientRegisterResponsePb *response) {
|
||||||
|
Status s = server_->ClientRegister(request->pid());
|
||||||
|
if (s.IsOk()) {
|
||||||
|
switch (server_->state()) {
|
||||||
|
case GraphDataServer::kGdsUninit:
|
||||||
|
case GraphDataServer::kGdsInitializing:
|
||||||
|
response->set_error_msg("Initializing");
|
||||||
|
break;
|
||||||
|
case GraphDataServer::kGdsRunning:
|
||||||
|
response->set_error_msg("Success");
|
||||||
|
response->set_data_schema(graph_data_impl_->GetDataSchema());
|
||||||
|
response->set_shared_memory_key(graph_data_impl_->GetSharedMemoryKey());
|
||||||
|
response->set_shared_memory_size(graph_data_impl_->GetSharedMemorySize());
|
||||||
|
s = FillDefaultFeature(response);
|
||||||
|
if (!s.IsOk()) {
|
||||||
|
response->set_error_msg(s.ToString());
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case GraphDataServer::kGdsStopped:
|
||||||
|
response->set_error_msg("Stoped");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
response->set_error_msg(s.ToString());
|
||||||
|
}
|
||||||
|
return ::grpc::Status::OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
grpc::Status GraphDataServiceImpl::ClientUnRegister(grpc::ServerContext *context,
|
||||||
|
const GnnClientUnRegisterRequestPb *request,
|
||||||
|
GnnClientUnRegisterResponsePb *response) {
|
||||||
|
Status s = server_->ClientUnRegister(request->pid());
|
||||||
|
if (s.IsOk()) {
|
||||||
|
response->set_error_msg("Success");
|
||||||
|
} else {
|
||||||
|
response->set_error_msg(s.ToString());
|
||||||
|
}
|
||||||
|
return ::grpc::Status::OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
grpc::Status GraphDataServiceImpl::GetGraphData(grpc::ServerContext *context, const GnnGraphDataRequestPb *request,
|
||||||
|
GnnGraphDataResponsePb *response) {
|
||||||
|
// MS_LOG(INFO) << "#### receive GetGraphData:" << request->op_name();
|
||||||
|
Status s;
|
||||||
|
auto iter = g_get_graph_data_func_.find(request->op_name());
|
||||||
|
if (iter != g_get_graph_data_func_.end()) {
|
||||||
|
pFunction func = iter->second;
|
||||||
|
s = (this->*func)(request, response);
|
||||||
|
if (s.IsOk()) {
|
||||||
|
response->set_error_msg("Success");
|
||||||
|
} else {
|
||||||
|
response->set_error_msg(s.ToString());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
response->set_error_msg("Invalid op name.");
|
||||||
|
}
|
||||||
|
// MS_LOG(INFO) << "#### end receive GetGraphData:" << request->op_name();
|
||||||
|
return ::grpc::Status::OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
grpc::Status GraphDataServiceImpl::GetMetaInfo(grpc::ServerContext *context, const GnnMetaInfoRequestPb *request,
|
||||||
|
GnnMetaInfoResponsePb *response) {
|
||||||
|
MetaInfo meta_info;
|
||||||
|
Status s = graph_data_impl_->GetMetaInfo(&meta_info);
|
||||||
|
if (s.IsOk()) {
|
||||||
|
response->set_error_msg("Success");
|
||||||
|
for (const auto &type : meta_info.node_type) {
|
||||||
|
auto node_info = response->add_node_info();
|
||||||
|
node_info->set_type(static_cast<google::protobuf::int32>(type));
|
||||||
|
auto itr = meta_info.node_num.find(type);
|
||||||
|
if (itr != meta_info.node_num.end()) {
|
||||||
|
node_info->set_num(static_cast<google::protobuf::int32>(itr->second));
|
||||||
|
} else {
|
||||||
|
node_info->set_num(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (const auto &type : meta_info.edge_type) {
|
||||||
|
auto edge_info = response->add_edge_info();
|
||||||
|
edge_info->set_type(static_cast<google::protobuf::int32>(type));
|
||||||
|
auto itr = meta_info.edge_num.find(type);
|
||||||
|
if (itr != meta_info.edge_num.end()) {
|
||||||
|
edge_info->set_num(static_cast<google::protobuf::int32>(itr->second));
|
||||||
|
} else {
|
||||||
|
edge_info->set_num(0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (const auto &type : meta_info.node_feature_type) {
|
||||||
|
response->add_node_feature_type(static_cast<google::protobuf::int32>(type));
|
||||||
|
}
|
||||||
|
for (const auto &type : meta_info.edge_feature_type) {
|
||||||
|
response->add_edge_feature_type(static_cast<google::protobuf::int32>(type));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
response->set_error_msg(s.ToString());
|
||||||
|
}
|
||||||
|
return ::grpc::Status::OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataServiceImpl::GetAllNodes(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) {
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() == 1, "The number of edge types is not 1");
|
||||||
|
|
||||||
|
std::shared_ptr<Tensor> tensor;
|
||||||
|
RETURN_IF_NOT_OK(graph_data_impl_->GetAllNodes(static_cast<NodeType>(request->type()[0]), &tensor));
|
||||||
|
TensorPb *result = response->add_result_data();
|
||||||
|
RETURN_IF_NOT_OK(TensorToPb(tensor, result));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataServiceImpl::GetAllEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) {
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() == 1, "The number of edge types is not 1");
|
||||||
|
|
||||||
|
std::shared_ptr<Tensor> tensor;
|
||||||
|
RETURN_IF_NOT_OK(graph_data_impl_->GetAllEdges(static_cast<EdgeType>(request->type()[0]), &tensor));
|
||||||
|
TensorPb *result = response->add_result_data();
|
||||||
|
RETURN_IF_NOT_OK(TensorToPb(tensor, result));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataServiceImpl::GetNodesFromEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) {
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input edge id is empty");
|
||||||
|
|
||||||
|
std::vector<EdgeIdType> edge_list;
|
||||||
|
edge_list.resize(request->id().size());
|
||||||
|
std::transform(request->id().begin(), request->id().end(), edge_list.begin(),
|
||||||
|
[](const google::protobuf::int32 id) { return static_cast<EdgeIdType>(id); });
|
||||||
|
std::shared_ptr<Tensor> tensor;
|
||||||
|
RETURN_IF_NOT_OK(graph_data_impl_->GetNodesFromEdges(edge_list, &tensor));
|
||||||
|
TensorPb *result = response->add_result_data();
|
||||||
|
RETURN_IF_NOT_OK(TensorToPb(tensor, result));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataServiceImpl::GetAllNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) {
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input node id is empty");
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() == 1, "The number of edge types is not 1");
|
||||||
|
|
||||||
|
std::vector<NodeIdType> node_list;
|
||||||
|
node_list.resize(request->id().size());
|
||||||
|
std::transform(request->id().begin(), request->id().end(), node_list.begin(),
|
||||||
|
[](const google::protobuf::int32 id) { return static_cast<NodeIdType>(id); });
|
||||||
|
std::shared_ptr<Tensor> tensor;
|
||||||
|
RETURN_IF_NOT_OK(graph_data_impl_->GetAllNeighbors(node_list, static_cast<NodeType>(request->type()[0]), &tensor));
|
||||||
|
TensorPb *result = response->add_result_data();
|
||||||
|
RETURN_IF_NOT_OK(TensorToPb(tensor, result));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataServiceImpl::GetSampledNeighbors(const GnnGraphDataRequestPb *request,
|
||||||
|
GnnGraphDataResponsePb *response) {
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input node id is empty");
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(request->number_size() > 0, "The input neighbor number is empty");
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() > 0, "The input neighbor type is empty");
|
||||||
|
|
||||||
|
std::vector<NodeIdType> node_list;
|
||||||
|
node_list.resize(request->id().size());
|
||||||
|
std::transform(request->id().begin(), request->id().end(), node_list.begin(),
|
||||||
|
[](const google::protobuf::int32 id) { return static_cast<NodeIdType>(id); });
|
||||||
|
std::vector<NodeIdType> neighbor_nums;
|
||||||
|
neighbor_nums.resize(request->number().size());
|
||||||
|
std::transform(request->number().begin(), request->number().end(), neighbor_nums.begin(),
|
||||||
|
[](const google::protobuf::int32 num) { return static_cast<NodeIdType>(num); });
|
||||||
|
std::vector<NodeType> neighbor_types;
|
||||||
|
neighbor_types.resize(request->type().size());
|
||||||
|
std::transform(request->type().begin(), request->type().end(), neighbor_types.begin(),
|
||||||
|
[](const google::protobuf::int32 type) { return static_cast<NodeType>(type); });
|
||||||
|
std::shared_ptr<Tensor> tensor;
|
||||||
|
RETURN_IF_NOT_OK(graph_data_impl_->GetSampledNeighbors(node_list, neighbor_nums, neighbor_types, &tensor));
|
||||||
|
TensorPb *result = response->add_result_data();
|
||||||
|
RETURN_IF_NOT_OK(TensorToPb(tensor, result));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataServiceImpl::GetNegSampledNeighbors(const GnnGraphDataRequestPb *request,
|
||||||
|
GnnGraphDataResponsePb *response) {
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input node id is empty");
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(request->number_size() == 1, "The number of neighbor number is not 1");
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() == 1, "The number of neighbor types is not 1");
|
||||||
|
|
||||||
|
std::vector<NodeIdType> node_list;
|
||||||
|
node_list.resize(request->id().size());
|
||||||
|
std::transform(request->id().begin(), request->id().end(), node_list.begin(),
|
||||||
|
[](const google::protobuf::int32 id) { return static_cast<NodeIdType>(id); });
|
||||||
|
std::shared_ptr<Tensor> tensor;
|
||||||
|
RETURN_IF_NOT_OK(graph_data_impl_->GetNegSampledNeighbors(node_list, static_cast<NodeIdType>(request->number()[0]),
|
||||||
|
static_cast<NodeType>(request->type()[0]), &tensor));
|
||||||
|
TensorPb *result = response->add_result_data();
|
||||||
|
RETURN_IF_NOT_OK(TensorToPb(tensor, result));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataServiceImpl::RandomWalk(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) {
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(request->id_size() > 0, "The input node id is empty");
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(request->type_size() > 0, "The input meta path is empty");
|
||||||
|
|
||||||
|
std::vector<NodeIdType> node_list;
|
||||||
|
node_list.resize(request->id().size());
|
||||||
|
std::transform(request->id().begin(), request->id().end(), node_list.begin(),
|
||||||
|
[](const google::protobuf::int32 id) { return static_cast<NodeIdType>(id); });
|
||||||
|
std::vector<NodeType> meta_path;
|
||||||
|
meta_path.resize(request->type().size());
|
||||||
|
std::transform(request->type().begin(), request->type().end(), meta_path.begin(),
|
||||||
|
[](const google::protobuf::int32 type) { return static_cast<NodeType>(type); });
|
||||||
|
std::shared_ptr<Tensor> tensor;
|
||||||
|
RETURN_IF_NOT_OK(graph_data_impl_->RandomWalk(node_list, meta_path, request->random_walk().p(),
|
||||||
|
request->random_walk().q(), request->random_walk().default_id(),
|
||||||
|
&tensor));
|
||||||
|
TensorPb *result = response->add_result_data();
|
||||||
|
RETURN_IF_NOT_OK(TensorToPb(tensor, result));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataServiceImpl::GetNodeFeature(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) {
|
||||||
|
std::shared_ptr<Tensor> nodes;
|
||||||
|
RETURN_IF_NOT_OK(PbToTensor(&request->id_tensor(), &nodes));
|
||||||
|
for (const auto &type : request->type()) {
|
||||||
|
std::shared_ptr<Tensor> tensor;
|
||||||
|
RETURN_IF_NOT_OK(graph_data_impl_->GetNodeFeatureSharedMemory(nodes, type, &tensor));
|
||||||
|
TensorPb *result = response->add_result_data();
|
||||||
|
RETURN_IF_NOT_OK(TensorToPb(tensor, result));
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphDataServiceImpl::GetEdgeFeature(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response) {
|
||||||
|
std::shared_ptr<Tensor> edges;
|
||||||
|
RETURN_IF_NOT_OK(PbToTensor(&request->id_tensor(), &edges));
|
||||||
|
for (const auto &type : request->type()) {
|
||||||
|
std::shared_ptr<Tensor> tensor;
|
||||||
|
RETURN_IF_NOT_OK(graph_data_impl_->GetEdgeFeatureSharedMemory(edges, type, &tensor));
|
||||||
|
TensorPb *result = response->add_result_data();
|
||||||
|
RETURN_IF_NOT_OK(TensorToPb(tensor, result));
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gnn
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,70 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVICE_IMPL_H_
|
||||||
|
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVICE_IMPL_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "minddata/dataset/engine/gnn/graph_data_impl.h"
|
||||||
|
#include "proto/gnn_graph_data.grpc.pb.h"
|
||||||
|
#include "proto/gnn_graph_data.pb.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
namespace gnn {
|
||||||
|
|
||||||
|
class GraphDataServer;
|
||||||
|
|
||||||
|
// class GraphDataServiceImpl : public GnnGraphData::Service {
|
||||||
|
class GraphDataServiceImpl {
|
||||||
|
public:
|
||||||
|
GraphDataServiceImpl(GraphDataServer *server, GraphDataImpl *graph_data_impl);
|
||||||
|
~GraphDataServiceImpl() = default;
|
||||||
|
|
||||||
|
grpc::Status ClientRegister(grpc::ServerContext *context, const GnnClientRegisterRequestPb *request,
|
||||||
|
GnnClientRegisterResponsePb *response);
|
||||||
|
|
||||||
|
grpc::Status ClientUnRegister(grpc::ServerContext *context, const GnnClientUnRegisterRequestPb *request,
|
||||||
|
GnnClientUnRegisterResponsePb *response);
|
||||||
|
|
||||||
|
grpc::Status GetGraphData(grpc::ServerContext *context, const GnnGraphDataRequestPb *request,
|
||||||
|
GnnGraphDataResponsePb *response);
|
||||||
|
|
||||||
|
grpc::Status GetMetaInfo(grpc::ServerContext *context, const GnnMetaInfoRequestPb *request,
|
||||||
|
GnnMetaInfoResponsePb *response);
|
||||||
|
|
||||||
|
Status GetAllNodes(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
|
||||||
|
Status GetAllEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
|
||||||
|
Status GetNodesFromEdges(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
|
||||||
|
Status GetAllNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
|
||||||
|
Status GetSampledNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
|
||||||
|
Status GetNegSampledNeighbors(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
|
||||||
|
Status RandomWalk(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
|
||||||
|
Status GetNodeFeature(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
|
||||||
|
Status GetEdgeFeature(const GnnGraphDataRequestPb *request, GnnGraphDataResponsePb *response);
|
||||||
|
|
||||||
|
private:
|
||||||
|
Status FillDefaultFeature(GnnClientRegisterResponsePb *response);
|
||||||
|
|
||||||
|
GraphDataServer *server_;
|
||||||
|
GraphDataImpl *graph_data_impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace gnn
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_SERVICE_IMPL_H_
|
|
@ -0,0 +1,106 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "minddata/dataset/engine/gnn/graph_feature_parser.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "mindspore/ccsrc/minddata/mindrecord/include/shard_error.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
namespace gnn {
|
||||||
|
|
||||||
|
using mindrecord::MSRStatus;
|
||||||
|
|
||||||
|
GraphFeatureParser::GraphFeatureParser(const ShardColumn &shard_column) {
|
||||||
|
shard_column_ = std::make_unique<ShardColumn>(shard_column);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphFeatureParser::LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &col_blob,
|
||||||
|
std::shared_ptr<Tensor> *tensor) {
|
||||||
|
const unsigned char *data = nullptr;
|
||||||
|
std::unique_ptr<unsigned char[]> data_ptr;
|
||||||
|
uint64_t n_bytes = 0, col_type_size = 1;
|
||||||
|
mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
|
||||||
|
std::vector<int64_t> column_shape;
|
||||||
|
MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type,
|
||||||
|
&col_type_size, &column_shape);
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key);
|
||||||
|
if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]);
|
||||||
|
RETURN_IF_NOT_OK(Tensor::CreateFromMemory(std::move(TensorShape({static_cast<dsize_t>(n_bytes / col_type_size)})),
|
||||||
|
std::move(DataType(mindrecord::ColumnDataTypeNameNormalized[col_type])),
|
||||||
|
data, tensor));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
Status GraphFeatureParser::LoadFeatureToSharedMemory(const std::string &key, const std::vector<uint8_t> &col_blob,
|
||||||
|
GraphSharedMemory *shared_memory,
|
||||||
|
std::shared_ptr<Tensor> *out_tensor) {
|
||||||
|
const unsigned char *data = nullptr;
|
||||||
|
std::unique_ptr<unsigned char[]> data_ptr;
|
||||||
|
uint64_t n_bytes = 0, col_type_size = 1;
|
||||||
|
mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
|
||||||
|
std::vector<int64_t> column_shape;
|
||||||
|
MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type,
|
||||||
|
&col_type_size, &column_shape);
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key);
|
||||||
|
if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]);
|
||||||
|
std::shared_ptr<Tensor> tensor;
|
||||||
|
RETURN_IF_NOT_OK(Tensor::CreateEmpty(std::move(TensorShape({2})), std::move(DataType(DataType::DE_INT64)), &tensor));
|
||||||
|
auto fea_itr = tensor->begin<int64_t>();
|
||||||
|
int64_t offset = 0;
|
||||||
|
RETURN_IF_NOT_OK(shared_memory->InsertData(data, n_bytes, &offset));
|
||||||
|
*fea_itr = offset;
|
||||||
|
++fea_itr;
|
||||||
|
*fea_itr = n_bytes;
|
||||||
|
*out_tensor = std::move(tensor);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
Status GraphFeatureParser::LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &col_blob,
|
||||||
|
std::vector<int32_t> *indices) {
|
||||||
|
const unsigned char *data = nullptr;
|
||||||
|
std::unique_ptr<unsigned char[]> data_ptr;
|
||||||
|
uint64_t n_bytes = 0, col_type_size = 1;
|
||||||
|
mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
|
||||||
|
std::vector<int64_t> column_shape;
|
||||||
|
MSRStatus rs = shard_column_->GetColumnValueByName(key, col_blob, {}, &data, &data_ptr, &n_bytes, &col_type,
|
||||||
|
&col_type_size, &column_shape);
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key);
|
||||||
|
|
||||||
|
if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]);
|
||||||
|
|
||||||
|
for (int i = 0; i < n_bytes; i += col_type_size) {
|
||||||
|
int32_t feature_ind = -1;
|
||||||
|
if (col_type == mindrecord::ColumnInt32) {
|
||||||
|
feature_ind = *(reinterpret_cast<const int32_t *>(data + i));
|
||||||
|
} else if (col_type == mindrecord::ColumnInt64) {
|
||||||
|
feature_ind = *(reinterpret_cast<const int64_t *>(data + i));
|
||||||
|
} else {
|
||||||
|
RETURN_STATUS_UNEXPECTED("Feature Index needs to be int32/int64 type!");
|
||||||
|
}
|
||||||
|
if (feature_ind >= 0) indices->push_back(feature_ind);
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gnn
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,67 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_FEATURE_PARSER_H_
|
||||||
|
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_FEATURE_PARSER_H_
|
||||||
|
#include <memory>
|
||||||
|
#include <queue>
|
||||||
|
#include <string>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "minddata/dataset/core/data_type.h"
|
||||||
|
#include "minddata/dataset/core/tensor.h"
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
#include "minddata/dataset/engine/gnn/graph_shared_memory.h"
|
||||||
|
#endif
|
||||||
|
#include "minddata/dataset/engine/gnn/feature.h"
|
||||||
|
#include "minddata/dataset/util/status.h"
|
||||||
|
#include "minddata/mindrecord/include/shard_column.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
namespace gnn {
|
||||||
|
|
||||||
|
using mindrecord::ShardColumn;
|
||||||
|
|
||||||
|
class GraphFeatureParser {
|
||||||
|
public:
|
||||||
|
explicit GraphFeatureParser(const ShardColumn &shard_column);
|
||||||
|
|
||||||
|
~GraphFeatureParser() = default;
|
||||||
|
|
||||||
|
// @param std::string key - column name
|
||||||
|
// @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
|
||||||
|
// @param std::vector<int32_t> *ind - return value, list of feature index in int32_t
|
||||||
|
// @return Status - the status code
|
||||||
|
Status LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &blob, std::vector<int32_t> *ind);
|
||||||
|
|
||||||
|
// @param std::string &key - column name
|
||||||
|
// @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
|
||||||
|
// @param std::shared_ptr<Tensor> *tensor - return value feature tensor
|
||||||
|
// @return Status - the status code
|
||||||
|
Status LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &blob, std::shared_ptr<Tensor> *tensor);
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
Status LoadFeatureToSharedMemory(const std::string &key, const std::vector<uint8_t> &col_blob,
|
||||||
|
GraphSharedMemory *shared_memory, std::shared_ptr<Tensor> *out_tensor);
|
||||||
|
#endif
|
||||||
|
private:
|
||||||
|
std::unique_ptr<ShardColumn> shard_column_;
|
||||||
|
};
|
||||||
|
} // namespace gnn
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_FEATURE_PARSER_H_
|
|
@ -13,41 +13,42 @@
|
||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
#include "minddata/dataset/engine/gnn/graph_loader.h"
|
||||||
|
|
||||||
#include <future>
|
#include <future>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "minddata/dataset/engine/gnn/graph_loader.h"
|
#include "minddata/dataset/engine/gnn/graph_data_impl.h"
|
||||||
#include "mindspore/ccsrc/minddata/mindrecord/include/shard_error.h"
|
|
||||||
#include "minddata/dataset/engine/gnn/local_edge.h"
|
#include "minddata/dataset/engine/gnn/local_edge.h"
|
||||||
#include "minddata/dataset/engine/gnn/local_node.h"
|
#include "minddata/dataset/engine/gnn/local_node.h"
|
||||||
#include "minddata/dataset/util/task_manager.h"
|
#include "minddata/dataset/util/task_manager.h"
|
||||||
|
#include "minddata/mindrecord/include/shard_error.h"
|
||||||
|
|
||||||
using ShardTuple = std::vector<std::tuple<std::vector<uint8_t>, mindspore::mindrecord::json>>;
|
using ShardTuple = std::vector<std::tuple<std::vector<uint8_t>, mindspore::mindrecord::json>>;
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
namespace gnn {
|
namespace gnn {
|
||||||
|
|
||||||
using mindrecord::MSRStatus;
|
using mindrecord::MSRStatus;
|
||||||
|
|
||||||
GraphLoader::GraphLoader(std::string mr_filepath, int32_t num_workers)
|
GraphLoader::GraphLoader(GraphDataImpl *graph_impl, std::string mr_filepath, int32_t num_workers, bool server_mode)
|
||||||
: mr_path_(mr_filepath),
|
: graph_impl_(graph_impl),
|
||||||
|
mr_path_(mr_filepath),
|
||||||
num_workers_(num_workers),
|
num_workers_(num_workers),
|
||||||
row_id_(0),
|
row_id_(0),
|
||||||
shard_reader_(nullptr),
|
shard_reader_(nullptr),
|
||||||
|
graph_feature_parser_(nullptr),
|
||||||
keys_({"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}) {}
|
keys_({"first_id", "second_id", "third_id", "attribute", "type", "node_feature_index", "edge_feature_index"}) {}
|
||||||
|
|
||||||
Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, NodeTypeMap *n_type_map,
|
Status GraphLoader::GetNodesAndEdges() {
|
||||||
EdgeTypeMap *e_type_map, NodeFeatureMap *n_feature_map,
|
NodeIdMap *n_id_map = &graph_impl_->node_id_map_;
|
||||||
EdgeFeatureMap *e_feature_map, DefaultNodeFeatureMap *default_node_feature_map,
|
EdgeIdMap *e_id_map = &graph_impl_->edge_id_map_;
|
||||||
DefaultEdgeFeatureMap *default_edge_feature_map) {
|
|
||||||
for (std::deque<std::shared_ptr<Node>> &dq : n_deques_) {
|
for (std::deque<std::shared_ptr<Node>> &dq : n_deques_) {
|
||||||
while (dq.empty() == false) {
|
while (dq.empty() == false) {
|
||||||
std::shared_ptr<Node> node_ptr = dq.front();
|
std::shared_ptr<Node> node_ptr = dq.front();
|
||||||
n_id_map->insert({node_ptr->id(), node_ptr});
|
n_id_map->insert({node_ptr->id(), node_ptr});
|
||||||
(*n_type_map)[node_ptr->type()].push_back(node_ptr->id());
|
graph_impl_->node_type_map_[node_ptr->type()].push_back(node_ptr->id());
|
||||||
dq.pop_front();
|
dq.pop_front();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -63,15 +64,15 @@ Status GraphLoader::GetNodesAndEdges(NodeIdMap *n_id_map, EdgeIdMap *e_id_map, N
|
||||||
RETURN_IF_NOT_OK(edge_ptr->SetNode({src_itr->second, dst_itr->second}));
|
RETURN_IF_NOT_OK(edge_ptr->SetNode({src_itr->second, dst_itr->second}));
|
||||||
RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second));
|
RETURN_IF_NOT_OK(src_itr->second->AddNeighbor(dst_itr->second));
|
||||||
e_id_map->insert({edge_ptr->id(), edge_ptr}); // add edge to edge_id_map_
|
e_id_map->insert({edge_ptr->id(), edge_ptr}); // add edge to edge_id_map_
|
||||||
(*e_type_map)[edge_ptr->type()].push_back(edge_ptr->id());
|
graph_impl_->edge_type_map_[edge_ptr->type()].push_back(edge_ptr->id());
|
||||||
dq.pop_front();
|
dq.pop_front();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto &itr : *n_type_map) itr.second.shrink_to_fit();
|
for (auto &itr : graph_impl_->node_type_map_) itr.second.shrink_to_fit();
|
||||||
for (auto &itr : *e_type_map) itr.second.shrink_to_fit();
|
for (auto &itr : graph_impl_->edge_type_map_) itr.second.shrink_to_fit();
|
||||||
|
|
||||||
MergeFeatureMaps(n_feature_map, e_feature_map, default_node_feature_map, default_edge_feature_map);
|
MergeFeatureMaps();
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -92,13 +93,26 @@ Status GraphLoader::InitAndLoad() {
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetShardHeader()->GetSchemaCount() > 0, "No schema found!");
|
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetShardHeader()->GetSchemaCount() > 0, "No schema found!");
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Launch(true) == MSRStatus::SUCCESS, "fail to launch mr");
|
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->Launch(true) == MSRStatus::SUCCESS, "fail to launch mr");
|
||||||
|
|
||||||
mindrecord::json schema = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema())["schema"];
|
graph_impl_->data_schema_ = (shard_reader_->GetShardHeader()->GetSchemas()[0]->GetSchema());
|
||||||
|
mindrecord::json schema = graph_impl_->data_schema_["schema"];
|
||||||
for (const std::string &key : keys_) {
|
for (const std::string &key : keys_) {
|
||||||
if (schema.find(key) == schema.end()) {
|
if (schema.find(key) == schema.end()) {
|
||||||
RETURN_STATUS_UNEXPECTED(key + ":doesn't exist in schema:" + schema.dump());
|
RETURN_STATUS_UNEXPECTED(key + ":doesn't exist in schema:" + schema.dump());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (graph_impl_->server_mode_) {
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
int64_t total_blob_size = 0;
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(shard_reader_->GetTotalBlobSize(&total_blob_size) == MSRStatus::SUCCESS,
|
||||||
|
"failed to get total blob size");
|
||||||
|
graph_impl_->graph_shared_memory_ = std::make_unique<GraphSharedMemory>(total_blob_size, mr_path_);
|
||||||
|
RETURN_IF_NOT_OK(graph_impl_->graph_shared_memory_->CreateSharedMemory());
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
graph_feature_parser_ = std::make_unique<GraphFeatureParser>(*shard_reader_->GetShardColumn());
|
||||||
|
|
||||||
// launching worker threads
|
// launching worker threads
|
||||||
for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) {
|
for (int wkr_id = 0; wkr_id < num_workers_; ++wkr_id) {
|
||||||
RETURN_IF_NOT_OK(vg.CreateAsyncTask("GraphLoader", std::bind(&GraphLoader::WorkerEntry, this, wkr_id)));
|
RETURN_IF_NOT_OK(vg.CreateAsyncTask("GraphLoader", std::bind(&GraphLoader::WorkerEntry, this, wkr_id)));
|
||||||
|
@ -116,18 +130,39 @@ Status GraphLoader::LoadNode(const std::vector<uint8_t> &col_blob, const mindrec
|
||||||
NodeType node_type = static_cast<NodeType>(col_jsn["type"]);
|
NodeType node_type = static_cast<NodeType>(col_jsn["type"]);
|
||||||
(*node) = std::make_shared<LocalNode>(node_id, node_type);
|
(*node) = std::make_shared<LocalNode>(node_id, node_type);
|
||||||
std::vector<int32_t> indices;
|
std::vector<int32_t> indices;
|
||||||
RETURN_IF_NOT_OK(LoadFeatureIndex("node_feature_index", col_blob, col_jsn, &indices));
|
RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureIndex("node_feature_index", col_blob, &indices));
|
||||||
|
if (graph_impl_->server_mode_) {
|
||||||
for (int32_t ind : indices) {
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
std::shared_ptr<Tensor> tensor;
|
for (int32_t ind : indices) {
|
||||||
RETURN_IF_NOT_OK(LoadFeatureTensor("node_feature_" + std::to_string(ind), col_blob, col_jsn, &tensor));
|
std::shared_ptr<Tensor> tensor_sm;
|
||||||
RETURN_IF_NOT_OK((*node)->UpdateFeature(std::make_shared<Feature>(ind, tensor)));
|
RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureToSharedMemory(
|
||||||
(*feature_map)[node_type].insert(ind);
|
"node_feature_" + std::to_string(ind), col_blob, graph_impl_->graph_shared_memory_.get(), &tensor_sm));
|
||||||
if ((*default_feature)[ind] == nullptr) {
|
RETURN_IF_NOT_OK((*node)->UpdateFeature(std::make_shared<Feature>(ind, tensor_sm, true)));
|
||||||
std::shared_ptr<Tensor> zero_tensor;
|
(*feature_map)[node_type].insert(ind);
|
||||||
RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor));
|
if ((*default_feature)[ind] == nullptr) {
|
||||||
RETURN_IF_NOT_OK(zero_tensor->Zero());
|
std::shared_ptr<Tensor> tensor;
|
||||||
(*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor);
|
RETURN_IF_NOT_OK(
|
||||||
|
graph_feature_parser_->LoadFeatureTensor("node_feature_" + std::to_string(ind), col_blob, &tensor));
|
||||||
|
std::shared_ptr<Tensor> zero_tensor;
|
||||||
|
RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor));
|
||||||
|
RETURN_IF_NOT_OK(zero_tensor->Zero());
|
||||||
|
(*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
} else {
|
||||||
|
for (int32_t ind : indices) {
|
||||||
|
std::shared_ptr<Tensor> tensor;
|
||||||
|
RETURN_IF_NOT_OK(
|
||||||
|
graph_feature_parser_->LoadFeatureTensor("node_feature_" + std::to_string(ind), col_blob, &tensor));
|
||||||
|
RETURN_IF_NOT_OK((*node)->UpdateFeature(std::make_shared<Feature>(ind, tensor)));
|
||||||
|
(*feature_map)[node_type].insert(ind);
|
||||||
|
if ((*default_feature)[ind] == nullptr) {
|
||||||
|
std::shared_ptr<Tensor> zero_tensor;
|
||||||
|
RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor));
|
||||||
|
RETURN_IF_NOT_OK(zero_tensor->Zero());
|
||||||
|
(*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -143,63 +178,42 @@ Status GraphLoader::LoadEdge(const std::vector<uint8_t> &col_blob, const mindrec
|
||||||
std::shared_ptr<Node> dst = std::make_shared<LocalNode>(dst_id, -1);
|
std::shared_ptr<Node> dst = std::make_shared<LocalNode>(dst_id, -1);
|
||||||
(*edge) = std::make_shared<LocalEdge>(edge_id, edge_type, src, dst);
|
(*edge) = std::make_shared<LocalEdge>(edge_id, edge_type, src, dst);
|
||||||
std::vector<int32_t> indices;
|
std::vector<int32_t> indices;
|
||||||
RETURN_IF_NOT_OK(LoadFeatureIndex("edge_feature_index", col_blob, col_jsn, &indices));
|
RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureIndex("edge_feature_index", col_blob, &indices));
|
||||||
for (int32_t ind : indices) {
|
if (graph_impl_->server_mode_) {
|
||||||
std::shared_ptr<Tensor> tensor;
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
RETURN_IF_NOT_OK(LoadFeatureTensor("edge_feature_" + std::to_string(ind), col_blob, col_jsn, &tensor));
|
for (int32_t ind : indices) {
|
||||||
RETURN_IF_NOT_OK((*edge)->UpdateFeature(std::make_shared<Feature>(ind, tensor)));
|
std::shared_ptr<Tensor> tensor_sm;
|
||||||
(*feature_map)[edge_type].insert(ind);
|
RETURN_IF_NOT_OK(graph_feature_parser_->LoadFeatureToSharedMemory(
|
||||||
if ((*default_feature)[ind] == nullptr) {
|
"edge_feature_" + std::to_string(ind), col_blob, graph_impl_->graph_shared_memory_.get(), &tensor_sm));
|
||||||
std::shared_ptr<Tensor> zero_tensor;
|
RETURN_IF_NOT_OK((*edge)->UpdateFeature(std::make_shared<Feature>(ind, tensor_sm, true)));
|
||||||
RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor));
|
(*feature_map)[edge_type].insert(ind);
|
||||||
RETURN_IF_NOT_OK(zero_tensor->Zero());
|
if ((*default_feature)[ind] == nullptr) {
|
||||||
(*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor);
|
std::shared_ptr<Tensor> tensor;
|
||||||
|
RETURN_IF_NOT_OK(
|
||||||
|
graph_feature_parser_->LoadFeatureTensor("edge_feature_" + std::to_string(ind), col_blob, &tensor));
|
||||||
|
std::shared_ptr<Tensor> zero_tensor;
|
||||||
|
RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor));
|
||||||
|
RETURN_IF_NOT_OK(zero_tensor->Zero());
|
||||||
|
(*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
} else {
|
||||||
|
for (int32_t ind : indices) {
|
||||||
|
std::shared_ptr<Tensor> tensor;
|
||||||
|
RETURN_IF_NOT_OK(
|
||||||
|
graph_feature_parser_->LoadFeatureTensor("edge_feature_" + std::to_string(ind), col_blob, &tensor));
|
||||||
|
RETURN_IF_NOT_OK((*edge)->UpdateFeature(std::make_shared<Feature>(ind, tensor)));
|
||||||
|
(*feature_map)[edge_type].insert(ind);
|
||||||
|
if ((*default_feature)[ind] == nullptr) {
|
||||||
|
std::shared_ptr<Tensor> zero_tensor;
|
||||||
|
RETURN_IF_NOT_OK(Tensor::CreateEmpty(tensor->shape(), tensor->type(), &zero_tensor));
|
||||||
|
RETURN_IF_NOT_OK(zero_tensor->Zero());
|
||||||
|
(*default_feature)[ind] = std::make_shared<Feature>(ind, zero_tensor);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
Status GraphLoader::LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &col_blob,
|
|
||||||
const mindrecord::json &col_jsn, std::shared_ptr<Tensor> *tensor) {
|
|
||||||
const unsigned char *data = nullptr;
|
|
||||||
std::unique_ptr<unsigned char[]> data_ptr;
|
|
||||||
uint64_t n_bytes = 0, col_type_size = 1;
|
|
||||||
mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
|
|
||||||
std::vector<int64_t> column_shape;
|
|
||||||
MSRStatus rs = shard_reader_->GetShardColumn()->GetColumnValueByName(
|
|
||||||
key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, &col_type, &col_type_size, &column_shape);
|
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column" + key);
|
|
||||||
if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]);
|
|
||||||
RETURN_IF_NOT_OK(Tensor::CreateFromMemory(std::move(TensorShape({static_cast<dsize_t>(n_bytes / col_type_size)})),
|
|
||||||
std::move(DataType(mindrecord::ColumnDataTypeNameNormalized[col_type])),
|
|
||||||
data, tensor));
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
Status GraphLoader::LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &col_blob,
|
|
||||||
const mindrecord::json &col_jsn, std::vector<int32_t> *indices) {
|
|
||||||
const unsigned char *data = nullptr;
|
|
||||||
std::unique_ptr<unsigned char[]> data_ptr;
|
|
||||||
uint64_t n_bytes = 0, col_type_size = 1;
|
|
||||||
mindrecord::ColumnDataType col_type = mindrecord::ColumnNoDataType;
|
|
||||||
std::vector<int64_t> column_shape;
|
|
||||||
MSRStatus rs = shard_reader_->GetShardColumn()->GetColumnValueByName(
|
|
||||||
key, col_blob, col_jsn, &data, &data_ptr, &n_bytes, &col_type, &col_type_size, &column_shape);
|
|
||||||
CHECK_FAIL_RETURN_UNEXPECTED(rs == mindrecord::SUCCESS, "fail to load column:" + key);
|
|
||||||
|
|
||||||
if (data == nullptr) data = reinterpret_cast<const unsigned char *>(&data_ptr[0]);
|
|
||||||
|
|
||||||
for (int i = 0; i < n_bytes; i += col_type_size) {
|
|
||||||
int32_t feature_ind = -1;
|
|
||||||
if (col_type == mindrecord::ColumnInt32) {
|
|
||||||
feature_ind = *(reinterpret_cast<const int32_t *>(data + i));
|
|
||||||
} else if (col_type == mindrecord::ColumnInt64) {
|
|
||||||
feature_ind = *(reinterpret_cast<const int64_t *>(data + i));
|
|
||||||
} else {
|
|
||||||
RETURN_STATUS_UNEXPECTED("Feature Index needs to be int32/int64 type!");
|
|
||||||
}
|
|
||||||
if (feature_ind >= 0) indices->push_back(feature_ind);
|
|
||||||
}
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -234,21 +248,19 @@ Status GraphLoader::WorkerEntry(int32_t worker_id) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
void GraphLoader::MergeFeatureMaps(NodeFeatureMap *n_feature_map, EdgeFeatureMap *e_feature_map,
|
void GraphLoader::MergeFeatureMaps() {
|
||||||
DefaultNodeFeatureMap *default_node_feature_map,
|
|
||||||
DefaultEdgeFeatureMap *default_edge_feature_map) {
|
|
||||||
for (int wkr_id = 0; wkr_id < num_workers_; wkr_id++) {
|
for (int wkr_id = 0; wkr_id < num_workers_; wkr_id++) {
|
||||||
for (auto &m : n_feature_maps_[wkr_id]) {
|
for (auto &m : n_feature_maps_[wkr_id]) {
|
||||||
for (auto &n : m.second) (*n_feature_map)[m.first].insert(n);
|
for (auto &n : m.second) graph_impl_->node_feature_map_[m.first].insert(n);
|
||||||
}
|
}
|
||||||
for (auto &m : e_feature_maps_[wkr_id]) {
|
for (auto &m : e_feature_maps_[wkr_id]) {
|
||||||
for (auto &n : m.second) (*e_feature_map)[m.first].insert(n);
|
for (auto &n : m.second) graph_impl_->edge_feature_map_[m.first].insert(n);
|
||||||
}
|
}
|
||||||
for (auto &m : default_node_feature_maps_[wkr_id]) {
|
for (auto &m : default_node_feature_maps_[wkr_id]) {
|
||||||
(*default_node_feature_map)[m.first] = m.second;
|
graph_impl_->default_node_feature_map_[m.first] = m.second;
|
||||||
}
|
}
|
||||||
for (auto &m : default_edge_feature_maps_[wkr_id]) {
|
for (auto &m : default_edge_feature_maps_[wkr_id]) {
|
||||||
(*default_edge_feature_map)[m.first] = m.second;
|
graph_impl_->default_edge_feature_map_[m.first] = m.second;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
n_feature_maps_.clear();
|
n_feature_maps_.clear();
|
||||||
|
|
|
@ -26,10 +26,13 @@
|
||||||
|
|
||||||
#include "minddata/dataset/core/data_type.h"
|
#include "minddata/dataset/core/data_type.h"
|
||||||
#include "minddata/dataset/core/tensor.h"
|
#include "minddata/dataset/core/tensor.h"
|
||||||
#include "minddata/dataset/engine/gnn/feature.h"
|
|
||||||
#include "minddata/dataset/engine/gnn/graph.h"
|
|
||||||
#include "minddata/dataset/engine/gnn/node.h"
|
|
||||||
#include "minddata/dataset/engine/gnn/edge.h"
|
#include "minddata/dataset/engine/gnn/edge.h"
|
||||||
|
#include "minddata/dataset/engine/gnn/feature.h"
|
||||||
|
#include "minddata/dataset/engine/gnn/graph_feature_parser.h"
|
||||||
|
#if !defined(_WIN32) && !defined(_WIN64)
|
||||||
|
#include "minddata/dataset/engine/gnn/graph_shared_memory.h"
|
||||||
|
#endif
|
||||||
|
#include "minddata/dataset/engine/gnn/node.h"
|
||||||
#include "minddata/dataset/util/status.h"
|
#include "minddata/dataset/util/status.h"
|
||||||
#include "minddata/mindrecord/include/shard_reader.h"
|
#include "minddata/mindrecord/include/shard_reader.h"
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
|
@ -46,13 +49,15 @@ using EdgeFeatureMap = std::unordered_map<EdgeType, std::unordered_set<FeatureTy
|
||||||
using DefaultNodeFeatureMap = std::unordered_map<FeatureType, std::shared_ptr<Feature>>;
|
using DefaultNodeFeatureMap = std::unordered_map<FeatureType, std::shared_ptr<Feature>>;
|
||||||
using DefaultEdgeFeatureMap = std::unordered_map<FeatureType, std::shared_ptr<Feature>>;
|
using DefaultEdgeFeatureMap = std::unordered_map<FeatureType, std::shared_ptr<Feature>>;
|
||||||
|
|
||||||
|
class GraphDataImpl;
|
||||||
|
|
||||||
// this class interfaces with the underlying storage format (mindrecord)
|
// this class interfaces with the underlying storage format (mindrecord)
|
||||||
// it returns raw nodes and edges via GetNodesAndEdges
|
// it returns raw nodes and edges via GetNodesAndEdges
|
||||||
// it is then the responsibility of graph to construct itself based on the nodes and edges
|
// it is then the responsibility of graph to construct itself based on the nodes and edges
|
||||||
// if needed, this class could become a base where each derived class handles a specific storage format
|
// if needed, this class could become a base where each derived class handles a specific storage format
|
||||||
class GraphLoader {
|
class GraphLoader {
|
||||||
public:
|
public:
|
||||||
explicit GraphLoader(std::string mr_filepath, int32_t num_workers = 4);
|
GraphLoader(GraphDataImpl *graph_impl, std::string mr_filepath, int32_t num_workers = 4, bool server_mode = false);
|
||||||
|
|
||||||
~GraphLoader() = default;
|
~GraphLoader() = default;
|
||||||
// Init mindrecord and load everything into memory multi-threaded
|
// Init mindrecord and load everything into memory multi-threaded
|
||||||
|
@ -63,8 +68,7 @@ class GraphLoader {
|
||||||
// nodes and edges are added to map without any connection. That's because there nodes and edges are read in
|
// nodes and edges are added to map without any connection. That's because there nodes and edges are read in
|
||||||
// random order. src_node and dst_node in Edge are node_id only with -1 as type.
|
// random order. src_node and dst_node in Edge are node_id only with -1 as type.
|
||||||
// features attached to each node and edge are expected to be filled correctly
|
// features attached to each node and edge are expected to be filled correctly
|
||||||
Status GetNodesAndEdges(NodeIdMap *, EdgeIdMap *, NodeTypeMap *, EdgeTypeMap *, NodeFeatureMap *, EdgeFeatureMap *,
|
Status GetNodesAndEdges();
|
||||||
DefaultNodeFeatureMap *, DefaultEdgeFeatureMap *);
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
//
|
//
|
||||||
|
@ -92,29 +96,15 @@ class GraphLoader {
|
||||||
Status LoadEdge(const std::vector<uint8_t> &blob, const mindrecord::json &jsn, std::shared_ptr<Edge> *edge,
|
Status LoadEdge(const std::vector<uint8_t> &blob, const mindrecord::json &jsn, std::shared_ptr<Edge> *edge,
|
||||||
EdgeFeatureMap *feature_map, DefaultEdgeFeatureMap *default_feature);
|
EdgeFeatureMap *feature_map, DefaultEdgeFeatureMap *default_feature);
|
||||||
|
|
||||||
// @param std::string key - column name
|
|
||||||
// @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
|
|
||||||
// @param mindrecord::json &jsn - contains raw data
|
|
||||||
// @param std::vector<int32_t> *ind - return value, list of feature index in int32_t
|
|
||||||
// @return Status - the status code
|
|
||||||
Status LoadFeatureIndex(const std::string &key, const std::vector<uint8_t> &blob, const mindrecord::json &jsn,
|
|
||||||
std::vector<int32_t> *ind);
|
|
||||||
|
|
||||||
// @param std::string &key - column name
|
|
||||||
// @param std::vector<uint8_t> &blob - contains data in blob field in mindrecord
|
|
||||||
// @param mindrecord::json &jsn - contains raw data
|
|
||||||
// @param std::shared_ptr<Tensor> *tensor - return value feature tensor
|
|
||||||
// @return Status - the status code
|
|
||||||
Status LoadFeatureTensor(const std::string &key, const std::vector<uint8_t> &blob, const mindrecord::json &jsn,
|
|
||||||
std::shared_ptr<Tensor> *tensor);
|
|
||||||
|
|
||||||
// merge NodeFeatureMap and EdgeFeatureMap of each worker into 1
|
// merge NodeFeatureMap and EdgeFeatureMap of each worker into 1
|
||||||
void MergeFeatureMaps(NodeFeatureMap *, EdgeFeatureMap *, DefaultNodeFeatureMap *, DefaultEdgeFeatureMap *);
|
void MergeFeatureMaps();
|
||||||
|
|
||||||
|
GraphDataImpl *graph_impl_;
|
||||||
|
std::string mr_path_;
|
||||||
const int32_t num_workers_;
|
const int32_t num_workers_;
|
||||||
std::atomic_int row_id_;
|
std::atomic_int row_id_;
|
||||||
std::string mr_path_;
|
|
||||||
std::unique_ptr<ShardReader> shard_reader_;
|
std::unique_ptr<ShardReader> shard_reader_;
|
||||||
|
std::unique_ptr<GraphFeatureParser> graph_feature_parser_;
|
||||||
std::vector<std::deque<std::shared_ptr<Node>>> n_deques_;
|
std::vector<std::deque<std::shared_ptr<Node>>> n_deques_;
|
||||||
std::vector<std::deque<std::shared_ptr<Edge>>> e_deques_;
|
std::vector<std::deque<std::shared_ptr<Edge>>> e_deques_;
|
||||||
std::vector<NodeFeatureMap> n_feature_maps_;
|
std::vector<NodeFeatureMap> n_feature_maps_;
|
||||||
|
|
|
@ -0,0 +1,134 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "minddata/dataset/engine/gnn/graph_shared_memory.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
namespace gnn {
|
||||||
|
|
||||||
|
GraphSharedMemory::GraphSharedMemory(int64_t memory_size, key_t memory_key)
|
||||||
|
: memory_size_(memory_size),
|
||||||
|
memory_key_(memory_key),
|
||||||
|
memory_ptr_(nullptr),
|
||||||
|
memory_offset_(0),
|
||||||
|
is_new_create_(false) {
|
||||||
|
std::stringstream stream;
|
||||||
|
stream << std::hex << memory_key_;
|
||||||
|
memory_key_str_ = stream.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
GraphSharedMemory::GraphSharedMemory(int64_t memory_size, const std::string &mr_file)
|
||||||
|
: mr_file_(mr_file),
|
||||||
|
memory_size_(memory_size),
|
||||||
|
memory_key_(-1),
|
||||||
|
memory_ptr_(nullptr),
|
||||||
|
memory_offset_(0),
|
||||||
|
is_new_create_(false) {}
|
||||||
|
|
||||||
|
GraphSharedMemory::~GraphSharedMemory() {
|
||||||
|
if (is_new_create_) {
|
||||||
|
(void)DeleteSharedMemory();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphSharedMemory::CreateSharedMemory() {
|
||||||
|
if (memory_key_ == -1) {
|
||||||
|
// ftok to generate unique key
|
||||||
|
memory_key_ = ftok(mr_file_.data(), kGnnSharedMemoryId);
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(memory_key_ != -1, "Failed to get key of shared memory. file_name:" + mr_file_);
|
||||||
|
std::stringstream stream;
|
||||||
|
stream << std::hex << memory_key_;
|
||||||
|
memory_key_str_ = stream.str();
|
||||||
|
}
|
||||||
|
int shmflg = (0666 | IPC_CREAT | IPC_EXCL);
|
||||||
|
Status s = SharedMemoryImpl(shmflg);
|
||||||
|
if (s.IsOk()) {
|
||||||
|
is_new_create_ = true;
|
||||||
|
MS_LOG(INFO) << "Create shared memory success, key=0x" << memory_key_str_;
|
||||||
|
} else {
|
||||||
|
MS_LOG(WARNING) << "Shared memory with the same key may already exist, key=0x" << memory_key_str_;
|
||||||
|
shmflg = (0666 | IPC_CREAT);
|
||||||
|
s = SharedMemoryImpl(shmflg);
|
||||||
|
if (!s.IsOk()) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("Create shared memory fao;ed, key=0x" + memory_key_str_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphSharedMemory::GetSharedMemory() {
|
||||||
|
int shmflg = 0;
|
||||||
|
RETURN_IF_NOT_OK(SharedMemoryImpl(shmflg));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphSharedMemory::DeleteSharedMemory() {
|
||||||
|
int shmid = shmget(memory_key_, 0, 0);
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(shmid != -1, "Failed to get shared memory. key=0x" + memory_key_str_);
|
||||||
|
int result = shmctl(shmid, IPC_RMID, 0);
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(result != -1, "Failed to delete shared memory. key=0x" + memory_key_str_);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphSharedMemory::SharedMemoryImpl(const int &shmflg) {
|
||||||
|
// shmget returns an identifier in shmid
|
||||||
|
int shmid = shmget(memory_key_, memory_size_, shmflg);
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(shmid != -1, "Failed to get shared memory. key=0x" + memory_key_str_);
|
||||||
|
|
||||||
|
// shmat to attach to shared memory
|
||||||
|
auto data = shmat(shmid, reinterpret_cast<void *>(0), 0);
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(data != (char *)(-1), "Failed to address shared memory. key=0x" + memory_key_str_);
|
||||||
|
memory_ptr_ = reinterpret_cast<uint8_t *>(data);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphSharedMemory::InsertData(const uint8_t *data, int64_t len, int64_t *offset) {
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(data, "Input data is nullptr.");
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(len > 0, "Input len is invalid.");
|
||||||
|
|
||||||
|
std::lock_guard<std::mutex> lck(mutex_);
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED((memory_size_ - memory_offset_ >= len),
|
||||||
|
"Insufficient shared memory space to insert data.");
|
||||||
|
if (EOK != memcpy_s(memory_ptr_ + memory_offset_, memory_size_ - memory_offset_, data, len)) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("Failed to insert data into shared memory.");
|
||||||
|
}
|
||||||
|
*offset = memory_offset_;
|
||||||
|
memory_offset_ += len;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GraphSharedMemory::GetData(uint8_t *data, int64_t data_len, int64_t offset, int64_t get_data_len) {
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(data, "Input data is nullptr.");
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(get_data_len > 0, "Input get_data_len is invalid.");
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(data_len >= get_data_len, "Insufficient target address space.");
|
||||||
|
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(memory_size_ >= get_data_len + offset,
|
||||||
|
"get_data_len is too large, beyond the space of shared memory.");
|
||||||
|
if (EOK != memcpy_s(data, data_len, memory_ptr_ + offset, get_data_len)) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("Failed to insert data into shared memory.");
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gnn
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,72 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_SHARED_MEMORY_H_
|
||||||
|
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_SHARED_MEMORY_H_
|
||||||
|
|
||||||
|
#include <sys/ipc.h>
|
||||||
|
#include <sys/shm.h>
|
||||||
|
#include <mutex>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "minddata/dataset/util/status.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
namespace gnn {
|
||||||
|
|
||||||
|
const int kGnnSharedMemoryId = 65;
|
||||||
|
|
||||||
|
class GraphSharedMemory {
|
||||||
|
public:
|
||||||
|
explicit GraphSharedMemory(int64_t memory_size, key_t memory_key);
|
||||||
|
explicit GraphSharedMemory(int64_t memory_size, const std::string &mr_file);
|
||||||
|
|
||||||
|
~GraphSharedMemory();
|
||||||
|
|
||||||
|
// @param uint8_t** shared_memory - shared memory address
|
||||||
|
// @return Status - the status code
|
||||||
|
Status CreateSharedMemory();
|
||||||
|
|
||||||
|
// @param uint8_t** shared_memory - shared memory address
|
||||||
|
// @return Status - the status code
|
||||||
|
Status GetSharedMemory();
|
||||||
|
|
||||||
|
Status DeleteSharedMemory();
|
||||||
|
|
||||||
|
Status InsertData(const uint8_t *data, int64_t len, int64_t *offset);
|
||||||
|
|
||||||
|
Status GetData(uint8_t *data, int64_t data_len, int64_t offset, int64_t get_data_len);
|
||||||
|
|
||||||
|
key_t memory_key() { return memory_key_; }
|
||||||
|
|
||||||
|
int64_t memory_size() { return memory_size_; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
Status SharedMemoryImpl(const int &shmflg);
|
||||||
|
|
||||||
|
std::string mr_file_;
|
||||||
|
int64_t memory_size_;
|
||||||
|
key_t memory_key_;
|
||||||
|
std::string memory_key_str_;
|
||||||
|
uint8_t *memory_ptr_;
|
||||||
|
int64_t memory_offset_;
|
||||||
|
std::mutex mutex_;
|
||||||
|
bool is_new_create_;
|
||||||
|
};
|
||||||
|
} // namespace gnn
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_SHARED_MEMORY_H_
|
|
@ -0,0 +1,82 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "minddata/dataset/engine/gnn/grpc_async_server.h"
|
||||||
|
|
||||||
|
#include <limits>
|
||||||
|
|
||||||
|
#include "minddata/dataset/util/task_manager.h"
|
||||||
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
GrpcAsyncServer::GrpcAsyncServer(const std::string &host, int32_t port) : host_(host), port_(port) {}
|
||||||
|
|
||||||
|
GrpcAsyncServer::~GrpcAsyncServer() { Stop(); }
|
||||||
|
|
||||||
|
Status GrpcAsyncServer::Run() {
|
||||||
|
std::string server_address = host_ + ":" + std::to_string(port_);
|
||||||
|
grpc::ServerBuilder builder;
|
||||||
|
// Default message size for gRPC is 4MB. Increase it to 2g-1
|
||||||
|
builder.SetMaxReceiveMessageSize(std::numeric_limits<int32_t>::max());
|
||||||
|
builder.AddChannelArgument(GRPC_ARG_ALLOW_REUSEPORT, 0);
|
||||||
|
int port_tcpip = 0;
|
||||||
|
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials(), &port_tcpip);
|
||||||
|
RETURN_IF_NOT_OK(RegisterService(&builder));
|
||||||
|
cq_ = builder.AddCompletionQueue();
|
||||||
|
server_ = builder.BuildAndStart();
|
||||||
|
if (server_) {
|
||||||
|
MS_LOG(INFO) << "Server listening on " << server_address;
|
||||||
|
} else {
|
||||||
|
std::string errMsg = "Fail to start server. ";
|
||||||
|
if (port_tcpip != port_) {
|
||||||
|
errMsg += "Unable to bind to address " + server_address + ".";
|
||||||
|
}
|
||||||
|
RETURN_STATUS_UNEXPECTED(errMsg);
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GrpcAsyncServer::HandleRequest() {
|
||||||
|
bool success;
|
||||||
|
void *tag;
|
||||||
|
// We loop through the grpc queue. Each connection if successful
|
||||||
|
// will come back with our own tag which is an instance of CallData
|
||||||
|
// and we simply call its functor. But first we need to create these instances
|
||||||
|
// and inject them into the grpc queue.
|
||||||
|
RETURN_IF_NOT_OK(EnqueueRequest());
|
||||||
|
while (cq_->Next(&tag, &success)) {
|
||||||
|
RETURN_IF_INTERRUPTED();
|
||||||
|
if (success) {
|
||||||
|
RETURN_IF_NOT_OK(ProcessRequest(tag));
|
||||||
|
} else {
|
||||||
|
MS_LOG(DEBUG) << "cq_->Next failed.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
void GrpcAsyncServer::Stop() {
|
||||||
|
if (server_) {
|
||||||
|
server_->Shutdown();
|
||||||
|
}
|
||||||
|
// Always shutdown the completion queue after the server.
|
||||||
|
if (cq_) {
|
||||||
|
cq_->Shutdown();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,59 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRPC_ASYNC_SERVER_H_
|
||||||
|
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRPC_ASYNC_SERVER_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "grpcpp/grpcpp.h"
|
||||||
|
#include "grpcpp/impl/codegen/async_unary_call.h"
|
||||||
|
|
||||||
|
#include "minddata/dataset/util/status.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
/// \brief Async server base class
|
||||||
|
class GrpcAsyncServer {
|
||||||
|
public:
|
||||||
|
explicit GrpcAsyncServer(const std::string &host, int32_t port);
|
||||||
|
virtual ~GrpcAsyncServer();
|
||||||
|
/// \brief Brings up gRPC server
|
||||||
|
/// \return none
|
||||||
|
Status Run();
|
||||||
|
/// \brief Entry function to handle async server request
|
||||||
|
Status HandleRequest();
|
||||||
|
|
||||||
|
void Stop();
|
||||||
|
|
||||||
|
virtual Status RegisterService(grpc::ServerBuilder *builder) = 0;
|
||||||
|
|
||||||
|
virtual Status EnqueueRequest() = 0;
|
||||||
|
|
||||||
|
virtual Status ProcessRequest(void *tag) = 0;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
int32_t port_;
|
||||||
|
std::string host_;
|
||||||
|
std::unique_ptr<grpc::ServerCompletionQueue> cq_;
|
||||||
|
std::unique_ptr<grpc::Server> server_;
|
||||||
|
};
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRPC_ASYNC_SERVER_H_
|
|
@ -44,6 +44,7 @@ Status LocalEdge::UpdateFeature(const std::shared_ptr<Feature> &feature) {
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gnn
|
} // namespace gnn
|
||||||
} // namespace dataset
|
} // namespace dataset
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -20,10 +20,10 @@
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "minddata/dataset/util/status.h"
|
|
||||||
#include "minddata/dataset/engine/gnn/edge.h"
|
#include "minddata/dataset/engine/gnn/edge.h"
|
||||||
#include "minddata/dataset/engine/gnn/feature.h"
|
#include "minddata/dataset/engine/gnn/feature.h"
|
||||||
#include "minddata/dataset/engine/gnn/node.h"
|
#include "minddata/dataset/engine/gnn/node.h"
|
||||||
|
#include "minddata/dataset/util/status.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
|
@ -20,9 +20,9 @@
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "minddata/dataset/util/status.h"
|
|
||||||
#include "minddata/dataset/engine/gnn/node.h"
|
#include "minddata/dataset/engine/gnn/node.h"
|
||||||
#include "minddata/dataset/engine/gnn/feature.h"
|
#include "minddata/dataset/engine/gnn/feature.h"
|
||||||
|
#include "minddata/dataset/util/status.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
|
@ -20,8 +20,8 @@
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "minddata/dataset/util/status.h"
|
|
||||||
#include "minddata/dataset/engine/gnn/feature.h"
|
#include "minddata/dataset/engine/gnn/feature.h"
|
||||||
|
#include "minddata/dataset/util/status.h"
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace dataset {
|
namespace dataset {
|
||||||
|
|
|
@ -0,0 +1,84 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "minddata/dataset/engine/gnn/tensor_proto.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <utility>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
const std::unordered_map<DataTypePb, DataType::Type> g_pb2datatype_map{
|
||||||
|
{DataTypePb::DE_PB_UNKNOWN, DataType::DE_UNKNOWN}, {DataTypePb::DE_PB_BOOL, DataType::DE_BOOL},
|
||||||
|
{DataTypePb::DE_PB_INT8, DataType::DE_INT8}, {DataTypePb::DE_PB_UINT8, DataType::DE_UINT8},
|
||||||
|
{DataTypePb::DE_PB_INT16, DataType::DE_INT16}, {DataTypePb::DE_PB_UINT16, DataType::DE_UINT16},
|
||||||
|
{DataTypePb::DE_PB_INT32, DataType::DE_INT32}, {DataTypePb::DE_PB_UINT32, DataType::DE_UINT32},
|
||||||
|
{DataTypePb::DE_PB_INT64, DataType::DE_INT64}, {DataTypePb::DE_PB_UINT64, DataType::DE_UINT64},
|
||||||
|
{DataTypePb::DE_PB_FLOAT16, DataType::DE_FLOAT16}, {DataTypePb::DE_PB_FLOAT32, DataType::DE_FLOAT32},
|
||||||
|
{DataTypePb::DE_PB_FLOAT64, DataType::DE_FLOAT64}, {DataTypePb::DE_PB_STRING, DataType::DE_STRING},
|
||||||
|
};
|
||||||
|
|
||||||
|
const std::unordered_map<DataType::Type, DataTypePb> g_datatype2pb_map{
|
||||||
|
{DataType::DE_UNKNOWN, DataTypePb::DE_PB_UNKNOWN}, {DataType::DE_BOOL, DataTypePb::DE_PB_BOOL},
|
||||||
|
{DataType::DE_INT8, DataTypePb::DE_PB_INT8}, {DataType::DE_UINT8, DataTypePb::DE_PB_UINT8},
|
||||||
|
{DataType::DE_INT16, DataTypePb::DE_PB_INT16}, {DataType::DE_UINT16, DataTypePb::DE_PB_UINT16},
|
||||||
|
{DataType::DE_INT32, DataTypePb::DE_PB_INT32}, {DataType::DE_UINT32, DataTypePb::DE_PB_UINT32},
|
||||||
|
{DataType::DE_INT64, DataTypePb::DE_PB_INT64}, {DataType::DE_UINT64, DataTypePb::DE_PB_UINT64},
|
||||||
|
{DataType::DE_FLOAT16, DataTypePb::DE_PB_FLOAT16}, {DataType::DE_FLOAT32, DataTypePb::DE_PB_FLOAT32},
|
||||||
|
{DataType::DE_FLOAT64, DataTypePb::DE_PB_FLOAT64}, {DataType::DE_STRING, DataTypePb::DE_PB_STRING},
|
||||||
|
};
|
||||||
|
|
||||||
|
Status TensorToPb(const std::shared_ptr<Tensor> tensor, TensorPb *tensor_pb) {
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(tensor, "Parameter tensor is a null pointer");
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(tensor_pb, "Parameter tensor_pb is a null pointer");
|
||||||
|
|
||||||
|
std::vector<dsize_t> shape = tensor->shape().AsVector();
|
||||||
|
for (auto dim : shape) {
|
||||||
|
tensor_pb->add_dims(static_cast<google::protobuf::int64>(dim));
|
||||||
|
}
|
||||||
|
auto iter = g_datatype2pb_map.find(tensor->type().value());
|
||||||
|
if (iter == g_datatype2pb_map.end()) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("Invalid tensor type: " + tensor->type().ToString());
|
||||||
|
}
|
||||||
|
tensor_pb->set_tensor_type(iter->second);
|
||||||
|
tensor_pb->set_data(tensor->GetBuffer(), tensor->SizeInBytes());
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status PbToTensor(const TensorPb *tensor_pb, std::shared_ptr<Tensor> *tensor) {
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(tensor_pb, "Parameter tensor_pb is a null pointer");
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(tensor, "Parameter tensor is a null pointer");
|
||||||
|
|
||||||
|
std::vector<dsize_t> shape;
|
||||||
|
shape.resize(tensor_pb->dims().size());
|
||||||
|
std::transform(tensor_pb->dims().begin(), tensor_pb->dims().end(), shape.begin(),
|
||||||
|
[](const google::protobuf::int64 dim) { return static_cast<dsize_t>(dim); });
|
||||||
|
auto iter = g_pb2datatype_map.find(tensor_pb->tensor_type());
|
||||||
|
if (iter == g_pb2datatype_map.end()) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("Invalid Tensor_pb type: " + std::to_string(tensor_pb->tensor_type()));
|
||||||
|
}
|
||||||
|
DataType::Type type = iter->second;
|
||||||
|
std::shared_ptr<Tensor> tensor_out;
|
||||||
|
RETURN_IF_NOT_OK(Tensor::CreateFromMemory(TensorShape(shape), DataType(type),
|
||||||
|
reinterpret_cast<const unsigned char *>(tensor_pb->data().data()),
|
||||||
|
tensor_pb->data().size(), &tensor_out));
|
||||||
|
*tensor = std::move(tensor_out);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,36 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_TENSOR_PROTO_H_
|
||||||
|
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_TENSOR_PROTO_H_
|
||||||
|
|
||||||
|
#include <deque>
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "proto/gnn_tensor.pb.h"
|
||||||
|
#include "minddata/dataset/core/tensor.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
Status TensorToPb(const std::shared_ptr<Tensor> tensor, TensorPb *tensor_pb);
|
||||||
|
|
||||||
|
Status PbToTensor(const TensorPb *tensor_pb, std::shared_ptr<Tensor> *tensor);
|
||||||
|
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_TENSOR_PROTO_H_
|
|
@ -61,6 +61,7 @@ const std::unordered_map<std::string, ColumnDataType> ColumnDataTypeMap = {
|
||||||
class ShardColumn {
|
class ShardColumn {
|
||||||
public:
|
public:
|
||||||
explicit ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool compress_integer = true);
|
explicit ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool compress_integer = true);
|
||||||
|
explicit ShardColumn(const json &schema_json, bool compress_integer = true);
|
||||||
|
|
||||||
~ShardColumn() = default;
|
~ShardColumn() = default;
|
||||||
|
|
||||||
|
@ -72,23 +73,29 @@ class ShardColumn {
|
||||||
std::vector<int64_t> *column_shape);
|
std::vector<int64_t> *column_shape);
|
||||||
|
|
||||||
/// \brief compress blob
|
/// \brief compress blob
|
||||||
std::vector<uint8_t> CompressBlob(const std::vector<uint8_t> &blob);
|
std::vector<uint8_t> CompressBlob(const std::vector<uint8_t> &blob, int64_t *compression_size);
|
||||||
|
|
||||||
/// \brief check if blob compressed
|
/// \brief check if blob compressed
|
||||||
bool CheckCompressBlob() const { return has_compress_blob_; }
|
bool CheckCompressBlob() const { return has_compress_blob_; }
|
||||||
|
|
||||||
|
/// \brief getter
|
||||||
uint64_t GetNumBlobColumn() const { return num_blob_column_; }
|
uint64_t GetNumBlobColumn() const { return num_blob_column_; }
|
||||||
|
|
||||||
|
/// \brief getter
|
||||||
std::vector<std::string> GetColumnName() { return column_name_; }
|
std::vector<std::string> GetColumnName() { return column_name_; }
|
||||||
|
|
||||||
|
/// \brief getter
|
||||||
std::vector<ColumnDataType> GeColumnDataType() { return column_data_type_; }
|
std::vector<ColumnDataType> GeColumnDataType() { return column_data_type_; }
|
||||||
|
|
||||||
|
/// \brief getter
|
||||||
std::vector<std::vector<int64_t>> GetColumnShape() { return column_shape_; }
|
std::vector<std::vector<int64_t>> GetColumnShape() { return column_shape_; }
|
||||||
|
|
||||||
/// \brief get column value from blob
|
/// \brief get column value from blob
|
||||||
MSRStatus GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
|
MSRStatus GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
|
||||||
const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr,
|
const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr,
|
||||||
uint64_t *const n_bytes);
|
uint64_t *const n_bytes);
|
||||||
|
|
||||||
|
/// \brief get column type
|
||||||
std::pair<MSRStatus, ColumnCategory> GetColumnTypeByName(const std::string &column_name,
|
std::pair<MSRStatus, ColumnCategory> GetColumnTypeByName(const std::string &column_name,
|
||||||
ColumnDataType *column_data_type,
|
ColumnDataType *column_data_type,
|
||||||
uint64_t *column_data_type_size,
|
uint64_t *column_data_type_size,
|
||||||
|
@ -99,6 +106,9 @@ class ShardColumn {
|
||||||
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes);
|
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
/// \brief intialization
|
||||||
|
void Init(const json &schema_json, bool compress_integer = true);
|
||||||
|
|
||||||
/// \brief get float value from json
|
/// \brief get float value from json
|
||||||
template <typename T>
|
template <typename T>
|
||||||
MSRStatus GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value, bool use_double);
|
MSRStatus GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value, bool use_double);
|
||||||
|
|
|
@ -65,6 +65,11 @@ class ShardHeader {
|
||||||
/// \return the Statistic
|
/// \return the Statistic
|
||||||
std::vector<std::shared_ptr<Statistics>> GetStatistics();
|
std::vector<std::shared_ptr<Statistics>> GetStatistics();
|
||||||
|
|
||||||
|
/// \brief add the statistic and save it
|
||||||
|
/// \param[in] statistic info of slim size
|
||||||
|
/// \return null
|
||||||
|
int64_t GetSlimSizeStatistic(const json &slim_size_json);
|
||||||
|
|
||||||
/// \brief get the fields of the index
|
/// \brief get the fields of the index
|
||||||
/// \return the fields of the index
|
/// \return the fields of the index
|
||||||
std::vector<std::pair<uint64_t, std::string>> GetFields();
|
std::vector<std::pair<uint64_t, std::string>> GetFields();
|
||||||
|
@ -114,10 +119,14 @@ class ShardHeader {
|
||||||
|
|
||||||
uint64_t GetPageSize() const { return page_size_; }
|
uint64_t GetPageSize() const { return page_size_; }
|
||||||
|
|
||||||
|
uint64_t GetCompressionSize() const { return compression_size_; }
|
||||||
|
|
||||||
void SetHeaderSize(const uint64_t &header_size) { header_size_ = header_size; }
|
void SetHeaderSize(const uint64_t &header_size) { header_size_ = header_size; }
|
||||||
|
|
||||||
void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; }
|
void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; }
|
||||||
|
|
||||||
|
void SetCompressionSize(const uint64_t &compression_size) { compression_size_ = compression_size; }
|
||||||
|
|
||||||
std::vector<std::string> SerializeHeader();
|
std::vector<std::string> SerializeHeader();
|
||||||
|
|
||||||
MSRStatus PagesToFile(const std::string dump_file_name);
|
MSRStatus PagesToFile(const std::string dump_file_name);
|
||||||
|
@ -177,6 +186,7 @@ class ShardHeader {
|
||||||
uint32_t shard_count_;
|
uint32_t shard_count_;
|
||||||
uint64_t header_size_;
|
uint64_t header_size_;
|
||||||
uint64_t page_size_;
|
uint64_t page_size_;
|
||||||
|
uint64_t compression_size_;
|
||||||
|
|
||||||
std::shared_ptr<Index> index_;
|
std::shared_ptr<Index> index_;
|
||||||
std::vector<std::string> shard_addresses_;
|
std::vector<std::string> shard_addresses_;
|
||||||
|
|
|
@ -209,6 +209,9 @@ class ShardReader {
|
||||||
/// \brief get all classes
|
/// \brief get all classes
|
||||||
MSRStatus GetAllClasses(const std::string &category_field, std::set<std::string> &categories);
|
MSRStatus GetAllClasses(const std::string &category_field, std::set<std::string> &categories);
|
||||||
|
|
||||||
|
/// \brief get the size of blob data
|
||||||
|
MSRStatus GetTotalBlobSize(int64_t *total_blob_size);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
/// \brief sqlite call back function
|
/// \brief sqlite call back function
|
||||||
static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names);
|
static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names);
|
||||||
|
@ -323,6 +326,7 @@ class ShardReader {
|
||||||
const std::string kThreadName = "THRD_ITER_"; // prefix of thread name
|
const std::string kThreadName = "THRD_ITER_"; // prefix of thread name
|
||||||
std::vector<std::thread> thread_set_; // thread list
|
std::vector<std::thread> thread_set_; // thread list
|
||||||
int num_rows_; // number of rows
|
int num_rows_; // number of rows
|
||||||
|
int64_t total_blob_size_; // total size of blob data
|
||||||
std::mutex mtx_delivery_; // locker for delivery
|
std::mutex mtx_delivery_; // locker for delivery
|
||||||
std::condition_variable cv_delivery_; // conditional variable for delivery
|
std::condition_variable cv_delivery_; // conditional variable for delivery
|
||||||
std::condition_variable cv_iterator_; // conditional variable for iterator
|
std::condition_variable cv_iterator_; // conditional variable for iterator
|
||||||
|
|
|
@ -257,6 +257,7 @@ class ShardWriter {
|
||||||
|
|
||||||
std::mutex check_mutex_; // mutex for data check
|
std::mutex check_mutex_; // mutex for data check
|
||||||
std::atomic<bool> flag_{false};
|
std::atomic<bool> flag_{false};
|
||||||
|
std::atomic<int64_t> compression_size_;
|
||||||
};
|
};
|
||||||
} // namespace mindrecord
|
} // namespace mindrecord
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -43,6 +43,7 @@ ShardReader::ShardReader() {
|
||||||
page_size_ = 0;
|
page_size_ = 0;
|
||||||
header_size_ = 0;
|
header_size_ = 0;
|
||||||
num_rows_ = 0;
|
num_rows_ = 0;
|
||||||
|
total_blob_size_ = 0;
|
||||||
num_padded_ = 0;
|
num_padded_ = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -55,9 +56,11 @@ std::pair<MSRStatus, std::vector<std::string>> ShardReader::GetMeta(const std::s
|
||||||
return {FAILED, {}};
|
return {FAILED, {}};
|
||||||
}
|
}
|
||||||
auto header = ret.second;
|
auto header = ret.second;
|
||||||
meta_data = {{"header_size", header["header_size"]}, {"page_size", header["page_size"]},
|
uint64_t compression_size = header.contains("compression_size") ? header["compression_size"].get<uint64_t>() : 0;
|
||||||
{"version", header["version"]}, {"index_fields", header["index_fields"]},
|
meta_data = {{"header_size", header["header_size"]}, {"page_size", header["page_size"]},
|
||||||
{"schema", header["schema"]}, {"blob_fields", header["blob_fields"]}};
|
{"compression_size", compression_size}, {"version", header["version"]},
|
||||||
|
{"index_fields", header["index_fields"]}, {"schema", header["schema"]},
|
||||||
|
{"blob_fields", header["blob_fields"]}};
|
||||||
return {SUCCESS, header["shard_addresses"]};
|
return {SUCCESS, header["shard_addresses"]};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -145,6 +148,11 @@ MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool loa
|
||||||
for (const auto &rg : row_group_summary) {
|
for (const auto &rg : row_group_summary) {
|
||||||
num_rows_ += std::get<3>(rg);
|
num_rows_ += std::get<3>(rg);
|
||||||
}
|
}
|
||||||
|
auto disk_size = page_size_ * row_group_summary.size();
|
||||||
|
auto compression_size = shard_header_->GetCompressionSize();
|
||||||
|
total_blob_size_ = disk_size + compression_size;
|
||||||
|
MS_LOG(INFO) << "Blob data size, on disk: " << disk_size << " , addtional uncompression: " << compression_size
|
||||||
|
<< " , Total: " << total_blob_size_;
|
||||||
|
|
||||||
MS_LOG(INFO) << "Get meta from mindrecord file & index file successfully.";
|
MS_LOG(INFO) << "Get meta from mindrecord file & index file successfully.";
|
||||||
|
|
||||||
|
@ -272,6 +280,11 @@ std::vector<std::tuple<int, int, int, uint64_t>> ShardReader::ReadRowGroupSummar
|
||||||
return row_group_summary;
|
return row_group_summary;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MSRStatus ShardReader::GetTotalBlobSize(int64_t *total_blob_size) {
|
||||||
|
*total_blob_size = total_blob_size_;
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels,
|
MSRStatus ShardReader::ConvertLabelToJson(const std::vector<std::vector<std::string>> &labels,
|
||||||
std::shared_ptr<std::fstream> fs,
|
std::shared_ptr<std::fstream> fs,
|
||||||
std::vector<std::vector<std::vector<uint64_t>>> &offsets, int shard_id,
|
std::vector<std::vector<std::vector<uint64_t>>> &offsets, int shard_id,
|
||||||
|
|
|
@ -28,11 +28,9 @@ using mindspore::MsLogLevel::INFO;
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace mindrecord {
|
namespace mindrecord {
|
||||||
ShardWriter::ShardWriter()
|
ShardWriter::ShardWriter()
|
||||||
: shard_count_(1),
|
: shard_count_(1), header_size_(kDefaultHeaderSize), page_size_(kDefaultPageSize), row_count_(0), schema_count_(1) {
|
||||||
header_size_(kDefaultHeaderSize),
|
compression_size_ = 0;
|
||||||
page_size_(kDefaultPageSize),
|
}
|
||||||
row_count_(0),
|
|
||||||
schema_count_(1) {}
|
|
||||||
|
|
||||||
ShardWriter::~ShardWriter() {
|
ShardWriter::~ShardWriter() {
|
||||||
for (int i = static_cast<int>(file_streams_.size()) - 1; i >= 0; i--) {
|
for (int i = static_cast<int>(file_streams_.size()) - 1; i >= 0; i--) {
|
||||||
|
@ -201,6 +199,7 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
|
||||||
if (ret == FAILED) {
|
if (ret == FAILED) {
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
|
compression_size_ = shard_header_->GetCompressionSize();
|
||||||
ret = Open(real_addresses, true);
|
ret = Open(real_addresses, true);
|
||||||
if (ret == FAILED) {
|
if (ret == FAILED) {
|
||||||
MS_LOG(ERROR) << "Open file failed";
|
MS_LOG(ERROR) << "Open file failed";
|
||||||
|
@ -614,7 +613,9 @@ MSRStatus ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>
|
||||||
// compress blob
|
// compress blob
|
||||||
if (shard_column_->CheckCompressBlob()) {
|
if (shard_column_->CheckCompressBlob()) {
|
||||||
for (auto &blob : blob_data) {
|
for (auto &blob : blob_data) {
|
||||||
blob = shard_column_->CompressBlob(blob);
|
int64_t compression_bytes = 0;
|
||||||
|
blob = shard_column_->CompressBlob(blob, &compression_bytes);
|
||||||
|
compression_size_ += compression_bytes;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1177,6 +1178,11 @@ MSRStatus ShardWriter::WriteShardHeader() {
|
||||||
MS_LOG(ERROR) << "Shard header is null";
|
MS_LOG(ERROR) << "Shard header is null";
|
||||||
return FAILED;
|
return FAILED;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int64_t compression_temp = compression_size_;
|
||||||
|
uint64_t compression_size = compression_temp > 0 ? compression_temp : 0;
|
||||||
|
shard_header_->SetCompressionSize(compression_size);
|
||||||
|
|
||||||
auto shard_header = shard_header_->SerializeHeader();
|
auto shard_header = shard_header_->SerializeHeader();
|
||||||
// Write header data to multi files
|
// Write header data to multi files
|
||||||
if (shard_count_ > static_cast<int>(file_streams_.size()) || shard_count_ > static_cast<int>(shard_header.size())) {
|
if (shard_count_ > static_cast<int>(file_streams_.size()) || shard_count_ > static_cast<int>(shard_header.size())) {
|
||||||
|
|
|
@ -24,7 +24,15 @@ namespace mindspore {
|
||||||
namespace mindrecord {
|
namespace mindrecord {
|
||||||
ShardColumn::ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool compress_integer) {
|
ShardColumn::ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool compress_integer) {
|
||||||
auto first_schema = shard_header->GetSchemas()[0];
|
auto first_schema = shard_header->GetSchemas()[0];
|
||||||
auto schema = first_schema->GetSchema()["schema"];
|
json schema_json = first_schema->GetSchema();
|
||||||
|
Init(schema_json, compress_integer);
|
||||||
|
}
|
||||||
|
|
||||||
|
ShardColumn::ShardColumn(const json &schema_json, bool compress_integer) { Init(schema_json, compress_integer); }
|
||||||
|
|
||||||
|
void ShardColumn::Init(const json &schema_json, bool compress_integer) {
|
||||||
|
auto schema = schema_json["schema"];
|
||||||
|
auto blob_fields = schema_json["blob_fields"];
|
||||||
|
|
||||||
bool has_integer_array = false;
|
bool has_integer_array = false;
|
||||||
for (json::iterator it = schema.begin(); it != schema.end(); ++it) {
|
for (json::iterator it = schema.begin(); it != schema.end(); ++it) {
|
||||||
|
@ -52,8 +60,6 @@ ShardColumn::ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool
|
||||||
column_name_id_[column_name_[i]] = i;
|
column_name_id_[column_name_[i]] = i;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto blob_fields = first_schema->GetBlobFields();
|
|
||||||
|
|
||||||
for (const auto &field : blob_fields) {
|
for (const auto &field : blob_fields) {
|
||||||
blob_column_.push_back(field);
|
blob_column_.push_back(field);
|
||||||
}
|
}
|
||||||
|
@ -282,8 +288,9 @@ ColumnCategory ShardColumn::CheckColumnName(const std::string &column_name) {
|
||||||
return it_blob == blob_column_id_.end() ? ColumnInRaw : ColumnInBlob;
|
return it_blob == blob_column_id_.end() ? ColumnInRaw : ColumnInBlob;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob) {
|
std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob, int64_t *compression_size) {
|
||||||
// Skip if no compress columns
|
// Skip if no compress columns
|
||||||
|
*compression_size = 0;
|
||||||
if (!CheckCompressBlob()) return blob;
|
if (!CheckCompressBlob()) return blob;
|
||||||
|
|
||||||
std::vector<uint8_t> dst_blob;
|
std::vector<uint8_t> dst_blob;
|
||||||
|
@ -295,7 +302,9 @@ std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob)
|
||||||
|
|
||||||
// Compress and return is blob has 1 column only
|
// Compress and return is blob has 1 column only
|
||||||
if (num_blob_column_ == 1) {
|
if (num_blob_column_ == 1) {
|
||||||
return CompressInt(blob, int_type);
|
dst_blob = CompressInt(blob, int_type);
|
||||||
|
*compression_size = static_cast<int64_t>(blob.size()) - static_cast<int64_t>(dst_blob.size());
|
||||||
|
return dst_blob;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Just copy and continue if column dat type is not int32/int64
|
// Just copy and continue if column dat type is not int32/int64
|
||||||
|
@ -319,6 +328,7 @@ std::vector<uint8_t> ShardColumn::CompressBlob(const std::vector<uint8_t> &blob)
|
||||||
i_src += kInt64Len + num_bytes;
|
i_src += kInt64Len + num_bytes;
|
||||||
}
|
}
|
||||||
MS_LOG(DEBUG) << "Compress all blob from " << blob.size() << " to " << dst_blob.size() << ".";
|
MS_LOG(DEBUG) << "Compress all blob from " << blob.size() << " to " << dst_blob.size() << ".";
|
||||||
|
*compression_size = static_cast<int64_t>(blob.size()) - static_cast<int64_t>(dst_blob.size());
|
||||||
return dst_blob;
|
return dst_blob;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,9 @@ using mindspore::MsLogLevel::ERROR;
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace mindrecord {
|
namespace mindrecord {
|
||||||
std::atomic<bool> thread_status(false);
|
std::atomic<bool> thread_status(false);
|
||||||
ShardHeader::ShardHeader() : shard_count_(0), header_size_(0), page_size_(0) { index_ = std::make_shared<Index>(); }
|
ShardHeader::ShardHeader() : shard_count_(0), header_size_(0), page_size_(0), compression_size_(0) {
|
||||||
|
index_ = std::make_shared<Index>();
|
||||||
|
}
|
||||||
|
|
||||||
MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers, bool load_dataset) {
|
MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers, bool load_dataset) {
|
||||||
shard_count_ = headers.size();
|
shard_count_ = headers.size();
|
||||||
|
@ -54,6 +56,7 @@ MSRStatus ShardHeader::InitializeHeader(const std::vector<json> &headers, bool l
|
||||||
ParseShardAddress(header["shard_addresses"]);
|
ParseShardAddress(header["shard_addresses"]);
|
||||||
header_size_ = header["header_size"].get<uint64_t>();
|
header_size_ = header["header_size"].get<uint64_t>();
|
||||||
page_size_ = header["page_size"].get<uint64_t>();
|
page_size_ = header["page_size"].get<uint64_t>();
|
||||||
|
compression_size_ = header.contains("compression_size") ? header["compression_size"].get<uint64_t>() : 0;
|
||||||
}
|
}
|
||||||
if (SUCCESS != ParsePage(header["page"], shard_index, load_dataset)) {
|
if (SUCCESS != ParsePage(header["page"], shard_index, load_dataset)) {
|
||||||
return FAILED;
|
return FAILED;
|
||||||
|
@ -146,9 +149,12 @@ std::pair<MSRStatus, json> ShardHeader::BuildSingleHeader(const std::string &fil
|
||||||
return {FAILED, json()};
|
return {FAILED, json()};
|
||||||
}
|
}
|
||||||
json raw_header = ret.second;
|
json raw_header = ret.second;
|
||||||
|
uint64_t compression_size =
|
||||||
|
raw_header.contains("compression_size") ? raw_header["compression_size"].get<uint64_t>() : 0;
|
||||||
json header = {{"shard_addresses", raw_header["shard_addresses"]},
|
json header = {{"shard_addresses", raw_header["shard_addresses"]},
|
||||||
{"header_size", raw_header["header_size"]},
|
{"header_size", raw_header["header_size"]},
|
||||||
{"page_size", raw_header["page_size"]},
|
{"page_size", raw_header["page_size"]},
|
||||||
|
{"compression_size", compression_size},
|
||||||
{"index_fields", raw_header["index_fields"]},
|
{"index_fields", raw_header["index_fields"]},
|
||||||
{"blob_fields", raw_header["schema"][0]["blob_fields"]},
|
{"blob_fields", raw_header["schema"][0]["blob_fields"]},
|
||||||
{"schema", raw_header["schema"][0]["schema"]},
|
{"schema", raw_header["schema"][0]["schema"]},
|
||||||
|
@ -343,6 +349,7 @@ std::vector<std::string> ShardHeader::SerializeHeader() {
|
||||||
s += "\"index_fields\":" + index + ",";
|
s += "\"index_fields\":" + index + ",";
|
||||||
s += "\"page\":" + pages[shardId] + ",";
|
s += "\"page\":" + pages[shardId] + ",";
|
||||||
s += "\"page_size\":" + std::to_string(page_size_) + ",";
|
s += "\"page_size\":" + std::to_string(page_size_) + ",";
|
||||||
|
s += "\"compression_size\":" + std::to_string(compression_size_) + ",";
|
||||||
s += "\"schema\":" + schema + ",";
|
s += "\"schema\":" + schema + ",";
|
||||||
s += "\"shard_addresses\":" + address + ",";
|
s += "\"shard_addresses\":" + address + ",";
|
||||||
s += "\"shard_id\":" + std::to_string(shardId) + ",";
|
s += "\"shard_id\":" + std::to_string(shardId) + ",";
|
||||||
|
|
|
@ -3083,20 +3083,22 @@ def _cpp_sampler_fn(sampler, dataset):
|
||||||
yield tuple([np.array(x, copy=False) for x in val])
|
yield tuple([np.array(x, copy=False) for x in val])
|
||||||
|
|
||||||
|
|
||||||
def _cpp_sampler_fn_mp(sampler, dataset, num_worker):
|
def _cpp_sampler_fn_mp(sampler, dataset, num_worker, multi_process):
|
||||||
"""
|
"""
|
||||||
Multiprocessing generator function wrapper for mappable dataset with cpp sampler.
|
Multiprocessing generator function wrapper for mappable dataset with cpp sampler.
|
||||||
"""
|
"""
|
||||||
indices = sampler.get_indices()
|
indices = sampler.get_indices()
|
||||||
return _sampler_fn_mp(indices, dataset, num_worker)
|
sample_fn = SamplerFn(dataset, num_worker, multi_process)
|
||||||
|
return sample_fn.process(indices)
|
||||||
|
|
||||||
|
|
||||||
def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker):
|
def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker, multi_process):
|
||||||
"""
|
"""
|
||||||
Multiprocessing generator function wrapper for mappable dataset with python sampler.
|
Multiprocessing generator function wrapper for mappable dataset with python sampler.
|
||||||
"""
|
"""
|
||||||
indices = _fetch_py_sampler_indices(sampler, num_samples)
|
indices = _fetch_py_sampler_indices(sampler, num_samples)
|
||||||
return _sampler_fn_mp(indices, dataset, num_worker)
|
sample_fn = SamplerFn(dataset, num_worker, multi_process)
|
||||||
|
return sample_fn.process(indices)
|
||||||
|
|
||||||
|
|
||||||
def _fetch_py_sampler_indices(sampler, num_samples):
|
def _fetch_py_sampler_indices(sampler, num_samples):
|
||||||
|
@ -3130,63 +3132,92 @@ def _fill_worker_indices(workers, indices, idx):
|
||||||
return idx
|
return idx
|
||||||
|
|
||||||
|
|
||||||
def _sampler_fn_mp(indices, dataset, num_worker):
|
class SamplerFn:
|
||||||
"""
|
"""
|
||||||
Multiprocessing generator function wrapper master process.
|
Multiprocessing or multithread generator function wrapper master process.
|
||||||
"""
|
"""
|
||||||
workers = []
|
def __init__(self, dataset, num_worker, multi_process):
|
||||||
# Event for end of epoch
|
self.workers = []
|
||||||
eoe = multiprocessing.Event()
|
self.num_worker = num_worker
|
||||||
|
self.multi_process = multi_process
|
||||||
|
# Event for end of epoch
|
||||||
|
if multi_process is True:
|
||||||
|
self.eoe = multiprocessing.Event()
|
||||||
|
self.eof = multiprocessing.Event()
|
||||||
|
else:
|
||||||
|
self.eoe = threading.Event()
|
||||||
|
self.eof = threading.Event()
|
||||||
|
# Create workers
|
||||||
|
for _ in range(num_worker):
|
||||||
|
if multi_process is True:
|
||||||
|
worker = _GeneratorWorkerMp(dataset, self.eoe, self.eof)
|
||||||
|
else:
|
||||||
|
worker = _GeneratorWorkerMt(dataset, self.eoe, self.eof)
|
||||||
|
worker.daemon = True
|
||||||
|
self.workers.append(worker)
|
||||||
|
|
||||||
# Create workers
|
def process(self, indices):
|
||||||
for _ in range(num_worker):
|
"""
|
||||||
worker = _GeneratorWorker(dataset, eoe)
|
The main process, start the child process or child thread, and fill the index queue,
|
||||||
worker.daemon = True
|
get the result from the result and return.
|
||||||
workers.append(worker)
|
"""
|
||||||
|
# Fill initial index queues
|
||||||
|
idx_cursor = 0
|
||||||
|
idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor)
|
||||||
|
|
||||||
# Fill initial index queues
|
# Start all workers
|
||||||
idx_cursor = 0
|
for w in self.workers:
|
||||||
idx_cursor = _fill_worker_indices(workers, indices, idx_cursor)
|
w.start()
|
||||||
|
|
||||||
# Start all workers
|
# Fetch results
|
||||||
for w in workers:
|
for i in range(len(indices)):
|
||||||
w.start()
|
# Fetch result and put index
|
||||||
|
try:
|
||||||
|
result = self.workers[i % self.num_worker].get()
|
||||||
|
except queue.Empty:
|
||||||
|
raise Exception("Generator worker process timeout")
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
self.eof.set()
|
||||||
|
for w in self.workers:
|
||||||
|
w.terminate()
|
||||||
|
w.join()
|
||||||
|
raise Exception("Generator worker receives KeyboardInterrupt")
|
||||||
|
if idx_cursor < len(indices):
|
||||||
|
idx_cursor = _fill_worker_indices(self.workers, indices, idx_cursor)
|
||||||
|
# Set eoe event once all indices are sent
|
||||||
|
if idx_cursor == len(indices) and not self.eoe.is_set():
|
||||||
|
self.eoe.set()
|
||||||
|
yield tuple([np.array(x, copy=False) for x in result])
|
||||||
|
|
||||||
# Fetch results
|
def __del__(self):
|
||||||
for i in range(len(indices)):
|
self.eoe.set()
|
||||||
# Fetch result and put index
|
self.eof.set()
|
||||||
try:
|
if self.multi_process is False:
|
||||||
result = workers[i % num_worker].get()
|
for w in self.workers:
|
||||||
except queue.Empty:
|
|
||||||
raise Exception("Generator worker process timeout")
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
for w in workers:
|
|
||||||
w.terminate()
|
|
||||||
w.join()
|
w.join()
|
||||||
raise Exception("Generator worker receives KeyboardInterrupt")
|
|
||||||
if idx_cursor < len(indices):
|
|
||||||
idx_cursor = _fill_worker_indices(workers, indices, idx_cursor)
|
|
||||||
# Set eoe event once all indices are sent
|
|
||||||
if idx_cursor == len(indices) and not eoe.is_set():
|
|
||||||
eoe.set()
|
|
||||||
yield tuple([np.array(x, copy=False) for x in result])
|
|
||||||
|
|
||||||
|
|
||||||
def _generator_worker_loop(dataset, idx_queue, result_queue, eoe):
|
def _generator_worker_loop(dataset, idx_queue, result_queue, eoe, eof):
|
||||||
"""
|
"""
|
||||||
Multiprocessing generator worker process loop.
|
Multiprocessing or multithread generator worker process loop.
|
||||||
"""
|
"""
|
||||||
while True:
|
while True:
|
||||||
# Fetch index, block
|
# Fetch index, block
|
||||||
try:
|
try:
|
||||||
idx = idx_queue.get()
|
idx = idx_queue.get(timeout=10)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
raise Exception("Generator worker receives KeyboardInterrupt")
|
raise Exception("Generator worker receives KeyboardInterrupt")
|
||||||
|
except queue.Empty:
|
||||||
|
if eof.is_set() or eoe.is_set():
|
||||||
|
raise Exception("Generator worker receives queue.Empty")
|
||||||
|
continue
|
||||||
if idx is None:
|
if idx is None:
|
||||||
# When the queue is out of scope from master process, a None item can be fetched from the queue.
|
# When the queue is out of scope from master process, a None item can be fetched from the queue.
|
||||||
# Upon receiving None, worker process should check if EOE is set.
|
# Upon receiving None, worker process should check if EOE is set.
|
||||||
assert eoe.is_set(), ""
|
assert eoe.is_set(), ""
|
||||||
return
|
return
|
||||||
|
if eof.is_set():
|
||||||
|
return
|
||||||
# Fetch data, any exception from __getitem__ will terminate worker and timeout master process
|
# Fetch data, any exception from __getitem__ will terminate worker and timeout master process
|
||||||
result = dataset[idx]
|
result = dataset[idx]
|
||||||
# Send data, block
|
# Send data, block
|
||||||
|
@ -3195,17 +3226,19 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eoe):
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
raise Exception("Generator worker receives KeyboardInterrupt")
|
raise Exception("Generator worker receives KeyboardInterrupt")
|
||||||
del result, idx
|
del result, idx
|
||||||
|
if eoe.is_set() and idx_queue.empty():
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
class _GeneratorWorker(multiprocessing.Process):
|
class _GeneratorWorkerMt(threading.Thread):
|
||||||
"""
|
"""
|
||||||
Worker process for multiprocess Generator.
|
Worker process for multithread Generator.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dataset, eoe):
|
def __init__(self, dataset, eoe, eof):
|
||||||
self.idx_queue = multiprocessing.Queue(16)
|
self.idx_queue = queue.Queue(16)
|
||||||
self.res_queue = multiprocessing.Queue(16)
|
self.res_queue = queue.Queue(16)
|
||||||
super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe))
|
super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe, eof))
|
||||||
|
|
||||||
def put(self, item):
|
def put(self, item):
|
||||||
"""
|
"""
|
||||||
|
@ -3217,7 +3250,30 @@ class _GeneratorWorker(multiprocessing.Process):
|
||||||
"""
|
"""
|
||||||
Get function for worker result queue. Block with timeout.
|
Get function for worker result queue. Block with timeout.
|
||||||
"""
|
"""
|
||||||
return self.res_queue.get()
|
return self.res_queue.get(timeout=10)
|
||||||
|
|
||||||
|
|
||||||
|
class _GeneratorWorkerMp(multiprocessing.Process):
|
||||||
|
"""
|
||||||
|
Worker process for multiprocess Generator.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dataset, eoe, eof):
|
||||||
|
self.idx_queue = multiprocessing.Queue(16)
|
||||||
|
self.res_queue = multiprocessing.Queue(16)
|
||||||
|
super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe, eof))
|
||||||
|
|
||||||
|
def put(self, item):
|
||||||
|
"""
|
||||||
|
Put function for worker index queue. Never block. Raise queue.Full on failure.
|
||||||
|
"""
|
||||||
|
self.idx_queue.put_nowait(item)
|
||||||
|
|
||||||
|
def get(self):
|
||||||
|
"""
|
||||||
|
Get function for worker result queue. Block with timeout.
|
||||||
|
"""
|
||||||
|
return self.res_queue.get(timeout=10)
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
self.terminate()
|
self.terminate()
|
||||||
|
@ -3280,6 +3336,8 @@ class GeneratorDataset(MappableDataset):
|
||||||
When this argument is specified, 'num_samples' will not effect. Random accessible input is required.
|
When this argument is specified, 'num_samples' will not effect. Random accessible input is required.
|
||||||
shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only
|
shard_id (int, optional): The shard ID within num_shards (default=None). This argument should be specified only
|
||||||
when num_shards is also specified. Random accessible input is required.
|
when num_shards is also specified. Random accessible input is required.
|
||||||
|
python_multiprocessing (bool, optional): Parallelize python operations with multiple worker process. This
|
||||||
|
option could be beneficial if the python operation is computational heavy (default=True).
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> import mindspore.dataset as ds
|
>>> import mindspore.dataset as ds
|
||||||
|
@ -3316,12 +3374,14 @@ class GeneratorDataset(MappableDataset):
|
||||||
|
|
||||||
@check_generatordataset
|
@check_generatordataset
|
||||||
def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None,
|
def __init__(self, source, column_names=None, column_types=None, schema=None, num_samples=None,
|
||||||
num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None):
|
num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None,
|
||||||
|
python_multiprocessing=True):
|
||||||
super().__init__(num_parallel_workers)
|
super().__init__(num_parallel_workers)
|
||||||
self.source = source
|
self.source = source
|
||||||
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
|
self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id)
|
||||||
self.num_samples = num_samples
|
self.num_samples = num_samples
|
||||||
self.num_shards = num_shards
|
self.num_shards = num_shards
|
||||||
|
self.python_multiprocessing = python_multiprocessing
|
||||||
|
|
||||||
if column_names is not None and not isinstance(column_names, list):
|
if column_names is not None and not isinstance(column_names, list):
|
||||||
column_names = [column_names]
|
column_names = [column_names]
|
||||||
|
@ -3403,12 +3463,16 @@ class GeneratorDataset(MappableDataset):
|
||||||
sampler_instance.set_num_rows(len(self.source))
|
sampler_instance.set_num_rows(len(self.source))
|
||||||
sampler_instance.initialize()
|
sampler_instance.initialize()
|
||||||
if new_op.num_parallel_workers > 1:
|
if new_op.num_parallel_workers > 1:
|
||||||
new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, self.source, new_op.num_parallel_workers))
|
new_op.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, self.source,
|
||||||
|
new_op.num_parallel_workers,
|
||||||
|
self.python_multiprocessing))
|
||||||
else:
|
else:
|
||||||
new_op.source = (lambda: _cpp_sampler_fn(sampler_instance, self.source))
|
new_op.source = (lambda: _cpp_sampler_fn(sampler_instance, self.source))
|
||||||
else:
|
else:
|
||||||
if new_op.num_parallel_workers > 1:
|
if new_op.num_parallel_workers > 1:
|
||||||
new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, self.source, new_op.num_parallel_workers))
|
new_op.source = (lambda: _py_sampler_fn_mp(new_op.sampler, new_op.num_samples, self.source,
|
||||||
|
new_op.num_parallel_workers,
|
||||||
|
self.python_multiprocessing))
|
||||||
else:
|
else:
|
||||||
new_op.source = (lambda: _py_sampler_fn(new_op.sampler, new_op.num_samples, self.source))
|
new_op.source = (lambda: _py_sampler_fn(new_op.sampler, new_op.num_samples, self.source))
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -16,8 +16,11 @@
|
||||||
graphdata.py supports loading graph dataset for GNN network training,
|
graphdata.py supports loading graph dataset for GNN network training,
|
||||||
and provides operations related to graph data.
|
and provides operations related to graph data.
|
||||||
"""
|
"""
|
||||||
|
import atexit
|
||||||
|
import time
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from mindspore._c_dataengine import Graph
|
from mindspore._c_dataengine import GraphDataClient
|
||||||
|
from mindspore._c_dataengine import GraphDataServer
|
||||||
from mindspore._c_dataengine import Tensor
|
from mindspore._c_dataengine import Tensor
|
||||||
|
|
||||||
from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \
|
from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \
|
||||||
|
@ -34,14 +37,52 @@ class GraphData:
|
||||||
dataset_file (str): One of file names in dataset.
|
dataset_file (str): One of file names in dataset.
|
||||||
num_parallel_workers (int, optional): Number of workers to process the Dataset in parallel
|
num_parallel_workers (int, optional): Number of workers to process the Dataset in parallel
|
||||||
(default=None).
|
(default=None).
|
||||||
|
working_mode (str, optional): Set working mode, now support 'local'/'client'/'server' (default='local').
|
||||||
|
|
||||||
|
- 'local', used in non-distributed training scenarios.
|
||||||
|
|
||||||
|
- 'client', used in distributed training scenarios, the client does not load data,
|
||||||
|
but obtains data from the server.
|
||||||
|
|
||||||
|
- 'server', used in distributed training scenarios, the server loads the data
|
||||||
|
and is available to the client.
|
||||||
|
|
||||||
|
hostname (str, optional): Valid when working_mode is set to 'client' or 'server',
|
||||||
|
set the hostname of the graph data server (default='127.0.0.1').
|
||||||
|
port (int, optional): Valid when working_mode is set to 'client' or 'server',
|
||||||
|
set the port of the graph data server, the range is 1024-65535 (default=50051).
|
||||||
|
num_client (int, optional): Valid when working_mode is set to 'server',
|
||||||
|
set the number of clients expected to connect, and the server will allocate corresponding
|
||||||
|
resources according to this parameter (default=1).
|
||||||
|
auto_shutdown (bool, optional): Valid when working_mode is set to 'server',
|
||||||
|
Control when all clients have connected and no client connected to the server,
|
||||||
|
automatically exit the server (default=True).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@check_gnn_graphdata
|
@check_gnn_graphdata
|
||||||
def __init__(self, dataset_file, num_parallel_workers=None):
|
def __init__(self, dataset_file, num_parallel_workers=None, working_mode='local', hostname='127.0.0.1', port=50051,
|
||||||
|
num_client=1, auto_shutdown=True):
|
||||||
self._dataset_file = dataset_file
|
self._dataset_file = dataset_file
|
||||||
|
self._working_mode = working_mode
|
||||||
if num_parallel_workers is None:
|
if num_parallel_workers is None:
|
||||||
num_parallel_workers = 1
|
num_parallel_workers = 1
|
||||||
self._graph = Graph(dataset_file, num_parallel_workers)
|
|
||||||
|
def stop():
|
||||||
|
self._graph_data.stop()
|
||||||
|
atexit.register(stop)
|
||||||
|
|
||||||
|
if working_mode in ['local', 'client']:
|
||||||
|
self._graph_data = GraphDataClient(dataset_file, num_parallel_workers, working_mode, hostname, port)
|
||||||
|
|
||||||
|
if working_mode == 'server':
|
||||||
|
self._graph_data = GraphDataServer(
|
||||||
|
dataset_file, num_parallel_workers, hostname, port, num_client, auto_shutdown)
|
||||||
|
try:
|
||||||
|
while self._graph_data.is_stoped() is not True:
|
||||||
|
time.sleep(1)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
# self._graph_data.stop()
|
||||||
|
raise Exception("Graph data server receives KeyboardInterrupt")
|
||||||
|
|
||||||
@check_gnn_get_all_nodes
|
@check_gnn_get_all_nodes
|
||||||
def get_all_nodes(self, node_type):
|
def get_all_nodes(self, node_type):
|
||||||
|
@ -62,7 +103,9 @@ class GraphData:
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: If `node_type` is not integer.
|
TypeError: If `node_type` is not integer.
|
||||||
"""
|
"""
|
||||||
return self._graph.get_all_nodes(node_type).as_array()
|
if self._working_mode == 'server':
|
||||||
|
raise Exception("This method is not supported when working mode is server")
|
||||||
|
return self._graph_data.get_all_nodes(node_type).as_array()
|
||||||
|
|
||||||
@check_gnn_get_all_edges
|
@check_gnn_get_all_edges
|
||||||
def get_all_edges(self, edge_type):
|
def get_all_edges(self, edge_type):
|
||||||
|
@ -83,7 +126,9 @@ class GraphData:
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: If `edge_type` is not integer.
|
TypeError: If `edge_type` is not integer.
|
||||||
"""
|
"""
|
||||||
return self._graph.get_all_edges(edge_type).as_array()
|
if self._working_mode == 'server':
|
||||||
|
raise Exception("This method is not supported when working mode is server")
|
||||||
|
return self._graph_data.get_all_edges(edge_type).as_array()
|
||||||
|
|
||||||
@check_gnn_get_nodes_from_edges
|
@check_gnn_get_nodes_from_edges
|
||||||
def get_nodes_from_edges(self, edge_list):
|
def get_nodes_from_edges(self, edge_list):
|
||||||
|
@ -99,7 +144,9 @@ class GraphData:
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: If `edge_list` is not list or ndarray.
|
TypeError: If `edge_list` is not list or ndarray.
|
||||||
"""
|
"""
|
||||||
return self._graph.get_nodes_from_edges(edge_list).as_array()
|
if self._working_mode == 'server':
|
||||||
|
raise Exception("This method is not supported when working mode is server")
|
||||||
|
return self._graph_data.get_nodes_from_edges(edge_list).as_array()
|
||||||
|
|
||||||
@check_gnn_get_all_neighbors
|
@check_gnn_get_all_neighbors
|
||||||
def get_all_neighbors(self, node_list, neighbor_type):
|
def get_all_neighbors(self, node_list, neighbor_type):
|
||||||
|
@ -123,7 +170,9 @@ class GraphData:
|
||||||
TypeError: If `node_list` is not list or ndarray.
|
TypeError: If `node_list` is not list or ndarray.
|
||||||
TypeError: If `neighbor_type` is not integer.
|
TypeError: If `neighbor_type` is not integer.
|
||||||
"""
|
"""
|
||||||
return self._graph.get_all_neighbors(node_list, neighbor_type).as_array()
|
if self._working_mode == 'server':
|
||||||
|
raise Exception("This method is not supported when working mode is server")
|
||||||
|
return self._graph_data.get_all_neighbors(node_list, neighbor_type).as_array()
|
||||||
|
|
||||||
@check_gnn_get_sampled_neighbors
|
@check_gnn_get_sampled_neighbors
|
||||||
def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types):
|
def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types):
|
||||||
|
@ -155,7 +204,9 @@ class GraphData:
|
||||||
TypeError: If `neighbor_nums` is not list or ndarray.
|
TypeError: If `neighbor_nums` is not list or ndarray.
|
||||||
TypeError: If `neighbor_types` is not list or ndarray.
|
TypeError: If `neighbor_types` is not list or ndarray.
|
||||||
"""
|
"""
|
||||||
return self._graph.get_sampled_neighbors(
|
if self._working_mode == 'server':
|
||||||
|
raise Exception("This method is not supported when working mode is server")
|
||||||
|
return self._graph_data.get_sampled_neighbors(
|
||||||
node_list, neighbor_nums, neighbor_types).as_array()
|
node_list, neighbor_nums, neighbor_types).as_array()
|
||||||
|
|
||||||
@check_gnn_get_neg_sampled_neighbors
|
@check_gnn_get_neg_sampled_neighbors
|
||||||
|
@ -182,7 +233,9 @@ class GraphData:
|
||||||
TypeError: If `neg_neighbor_num` is not integer.
|
TypeError: If `neg_neighbor_num` is not integer.
|
||||||
TypeError: If `neg_neighbor_type` is not integer.
|
TypeError: If `neg_neighbor_type` is not integer.
|
||||||
"""
|
"""
|
||||||
return self._graph.get_neg_sampled_neighbors(
|
if self._working_mode == 'server':
|
||||||
|
raise Exception("This method is not supported when working mode is server")
|
||||||
|
return self._graph_data.get_neg_sampled_neighbors(
|
||||||
node_list, neg_neighbor_num, neg_neighbor_type).as_array()
|
node_list, neg_neighbor_num, neg_neighbor_type).as_array()
|
||||||
|
|
||||||
@check_gnn_get_node_feature
|
@check_gnn_get_node_feature
|
||||||
|
@ -207,10 +260,12 @@ class GraphData:
|
||||||
TypeError: If `node_list` is not list or ndarray.
|
TypeError: If `node_list` is not list or ndarray.
|
||||||
TypeError: If `feature_types` is not list or ndarray.
|
TypeError: If `feature_types` is not list or ndarray.
|
||||||
"""
|
"""
|
||||||
|
if self._working_mode == 'server':
|
||||||
|
raise Exception("This method is not supported when working mode is server")
|
||||||
if isinstance(node_list, list):
|
if isinstance(node_list, list):
|
||||||
node_list = np.array(node_list, dtype=np.int32)
|
node_list = np.array(node_list, dtype=np.int32)
|
||||||
return [
|
return [
|
||||||
t.as_array() for t in self._graph.get_node_feature(
|
t.as_array() for t in self._graph_data.get_node_feature(
|
||||||
Tensor(node_list),
|
Tensor(node_list),
|
||||||
feature_types)]
|
feature_types)]
|
||||||
|
|
||||||
|
@ -236,10 +291,12 @@ class GraphData:
|
||||||
TypeError: If `edge_list` is not list or ndarray.
|
TypeError: If `edge_list` is not list or ndarray.
|
||||||
TypeError: If `feature_types` is not list or ndarray.
|
TypeError: If `feature_types` is not list or ndarray.
|
||||||
"""
|
"""
|
||||||
|
if self._working_mode == 'server':
|
||||||
|
raise Exception("This method is not supported when working mode is server")
|
||||||
if isinstance(edge_list, list):
|
if isinstance(edge_list, list):
|
||||||
edge_list = np.array(edge_list, dtype=np.int32)
|
edge_list = np.array(edge_list, dtype=np.int32)
|
||||||
return [
|
return [
|
||||||
t.as_array() for t in self._graph.get_edge_feature(
|
t.as_array() for t in self._graph_data.get_edge_feature(
|
||||||
Tensor(edge_list),
|
Tensor(edge_list),
|
||||||
feature_types)]
|
feature_types)]
|
||||||
|
|
||||||
|
@ -252,7 +309,9 @@ class GraphData:
|
||||||
dict: Meta information of the graph. The key is node_type, edge_type, node_num, edge_num,
|
dict: Meta information of the graph. The key is node_type, edge_type, node_num, edge_num,
|
||||||
node_feature_type and edge_feature_type.
|
node_feature_type and edge_feature_type.
|
||||||
"""
|
"""
|
||||||
return self._graph.graph_info()
|
if self._working_mode == 'server':
|
||||||
|
raise Exception("This method is not supported when working mode is server")
|
||||||
|
return self._graph_data.graph_info()
|
||||||
|
|
||||||
@check_gnn_random_walk
|
@check_gnn_random_walk
|
||||||
def random_walk(
|
def random_walk(
|
||||||
|
@ -285,5 +344,7 @@ class GraphData:
|
||||||
TypeError: If `target_nodes` is not list or ndarray.
|
TypeError: If `target_nodes` is not list or ndarray.
|
||||||
TypeError: If `meta_path` is not list or ndarray.
|
TypeError: If `meta_path` is not list or ndarray.
|
||||||
"""
|
"""
|
||||||
return self._graph.random_walk(target_nodes, meta_path, step_home_param, step_away_param,
|
if self._working_mode == 'server':
|
||||||
default_node).as_array()
|
raise Exception("This method is not supported when working mode is server")
|
||||||
|
return self._graph_data.random_walk(target_nodes, meta_path, step_home_param, step_away_param,
|
||||||
|
default_node).as_array()
|
||||||
|
|
|
@ -18,6 +18,7 @@ Built-in validators.
|
||||||
"""
|
"""
|
||||||
import inspect as ins
|
import inspect as ins
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -912,16 +913,36 @@ def check_split(method):
|
||||||
return new_method
|
return new_method
|
||||||
|
|
||||||
|
|
||||||
|
def check_hostname(hostname):
|
||||||
|
if len(hostname) > 255:
|
||||||
|
return False
|
||||||
|
if hostname[-1] == ".":
|
||||||
|
hostname = hostname[:-1] # strip exactly one dot from the right, if present
|
||||||
|
allowed = re.compile("(?!-)[A-Z\\d-]{1,63}(?<!-)$", re.IGNORECASE)
|
||||||
|
return all(allowed.match(x) for x in hostname.split("."))
|
||||||
|
|
||||||
|
|
||||||
def check_gnn_graphdata(method):
|
def check_gnn_graphdata(method):
|
||||||
"""check the input arguments of graphdata."""
|
"""check the input arguments of graphdata."""
|
||||||
|
|
||||||
@wraps(method)
|
@wraps(method)
|
||||||
def new_method(self, *args, **kwargs):
|
def new_method(self, *args, **kwargs):
|
||||||
[dataset_file, num_parallel_workers], _ = parse_user_args(method, *args, **kwargs)
|
[dataset_file, num_parallel_workers, working_mode, hostname,
|
||||||
|
port, num_client, auto_shutdown], _ = parse_user_args(method, *args, **kwargs)
|
||||||
check_file(dataset_file)
|
check_file(dataset_file)
|
||||||
|
|
||||||
if num_parallel_workers is not None:
|
if num_parallel_workers is not None:
|
||||||
check_num_parallel_workers(num_parallel_workers)
|
check_num_parallel_workers(num_parallel_workers)
|
||||||
|
type_check(hostname, (str,), "hostname")
|
||||||
|
if check_hostname(hostname) is False:
|
||||||
|
raise ValueError("The hostname is illegal")
|
||||||
|
type_check(working_mode, (str,), "working_mode")
|
||||||
|
if working_mode not in {'local', 'client', 'server'}:
|
||||||
|
raise ValueError("Invalid working mode")
|
||||||
|
type_check(port, (int,), "port")
|
||||||
|
check_value(port, (1024, 65535), "port")
|
||||||
|
type_check(num_client, (int,), "num_client")
|
||||||
|
check_value(num_client, (1, 255), "num_client")
|
||||||
|
type_check(auto_shutdown, (bool,), "auto_shutdown")
|
||||||
return method(self, *args, **kwargs)
|
return method(self, *args, **kwargs)
|
||||||
|
|
||||||
return new_method
|
return new_method
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
"""
|
"""
|
||||||
User-defined API for MindRecord GNN writer.
|
User-defined API for MindRecord GNN writer.
|
||||||
"""
|
"""
|
||||||
|
import numpy as np
|
||||||
social_data = [[348, 350], [348, 327], [348, 329], [348, 331], [348, 335],
|
social_data = [[348, 350], [348, 327], [348, 329], [348, 331], [348, 335],
|
||||||
[348, 336], [348, 337], [348, 338], [348, 340], [348, 341],
|
[348, 336], [348, 337], [348, 338], [348, 340], [348, 341],
|
||||||
[348, 342], [348, 343], [348, 344], [348, 345], [348, 346],
|
[348, 342], [348, 343], [348, 344], [348, 345], [348, 346],
|
||||||
|
@ -29,7 +30,7 @@ social_data = [[348, 350], [348, 327], [348, 329], [348, 331], [348, 335],
|
||||||
[355, 352], [353, 350], [352, 349], [351, 349], [350, 349]]
|
[355, 352], [353, 350], [352, 349], [351, 349], [350, 349]]
|
||||||
|
|
||||||
# profile: (num_features, feature_data_types, feature_shapes)
|
# profile: (num_features, feature_data_types, feature_shapes)
|
||||||
node_profile = (0, [], [])
|
node_profile = (2, ["int64", "int32"], [[-1], [-1]])
|
||||||
edge_profile = (0, [], [])
|
edge_profile = (0, [], [])
|
||||||
|
|
||||||
|
|
||||||
|
@ -51,7 +52,9 @@ def yield_nodes(task_id=0):
|
||||||
node_list.sort()
|
node_list.sort()
|
||||||
print(node_list)
|
print(node_list)
|
||||||
for node_id in node_list:
|
for node_id in node_list:
|
||||||
node = {'id': node_id, 'type': 1}
|
node = {'id': node_id, 'type': 1,
|
||||||
|
'feature_1': np.ones((5,), dtype=np.int64),
|
||||||
|
'feature_2': np.ones((10,), dtype=np.int32)}
|
||||||
yield node
|
yield node
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include "gtest/gtest.h"
|
#include "gtest/gtest.h"
|
||||||
#include "minddata/dataset/util/status.h"
|
#include "minddata/dataset/util/status.h"
|
||||||
#include "minddata/dataset/engine/gnn/node.h"
|
#include "minddata/dataset/engine/gnn/node.h"
|
||||||
|
#include "minddata/dataset/engine/gnn/graph_data_impl.h"
|
||||||
#include "minddata/dataset/engine/gnn/graph_loader.h"
|
#include "minddata/dataset/engine/gnn/graph_loader.h"
|
||||||
|
|
||||||
using namespace mindspore::dataset;
|
using namespace mindspore::dataset;
|
||||||
|
@ -39,30 +40,9 @@ class MindDataTestGNNGraph : public UT::Common {
|
||||||
MindDataTestGNNGraph() = default;
|
MindDataTestGNNGraph() = default;
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(MindDataTestGNNGraph, TestGraphLoader) {
|
|
||||||
std::string path = "data/mindrecord/testGraphData/testdata";
|
|
||||||
GraphLoader gl(path, 4);
|
|
||||||
EXPECT_TRUE(gl.InitAndLoad().IsOk());
|
|
||||||
NodeIdMap n_id_map;
|
|
||||||
EdgeIdMap e_id_map;
|
|
||||||
NodeTypeMap n_type_map;
|
|
||||||
EdgeTypeMap e_type_map;
|
|
||||||
NodeFeatureMap n_feature_map;
|
|
||||||
EdgeFeatureMap e_feature_map;
|
|
||||||
DefaultNodeFeatureMap default_node_feature_map;
|
|
||||||
DefaultEdgeFeatureMap default_edge_feature_map;
|
|
||||||
EXPECT_TRUE(gl.GetNodesAndEdges(&n_id_map, &e_id_map, &n_type_map, &e_type_map, &n_feature_map, &e_feature_map,
|
|
||||||
&default_node_feature_map, &default_edge_feature_map)
|
|
||||||
.IsOk());
|
|
||||||
EXPECT_EQ(n_id_map.size(), 20);
|
|
||||||
EXPECT_EQ(e_id_map.size(), 40);
|
|
||||||
EXPECT_EQ(n_type_map[2].size(), 10);
|
|
||||||
EXPECT_EQ(n_type_map[1].size(), 10);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) {
|
TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) {
|
||||||
std::string path = "data/mindrecord/testGraphData/testdata";
|
std::string path = "data/mindrecord/testGraphData/testdata";
|
||||||
Graph graph(path, 1);
|
GraphDataImpl graph(path, 1);
|
||||||
Status s = graph.Init();
|
Status s = graph.Init();
|
||||||
EXPECT_TRUE(s.IsOk());
|
EXPECT_TRUE(s.IsOk());
|
||||||
|
|
||||||
|
@ -103,7 +83,7 @@ TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) {
|
||||||
|
|
||||||
TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) {
|
TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) {
|
||||||
std::string path = "data/mindrecord/testGraphData/testdata";
|
std::string path = "data/mindrecord/testGraphData/testdata";
|
||||||
Graph graph(path, 1);
|
GraphDataImpl graph(path, 1);
|
||||||
Status s = graph.Init();
|
Status s = graph.Init();
|
||||||
EXPECT_TRUE(s.IsOk());
|
EXPECT_TRUE(s.IsOk());
|
||||||
|
|
||||||
|
@ -194,7 +174,7 @@ TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) {
|
||||||
|
|
||||||
TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) {
|
TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) {
|
||||||
std::string path = "data/mindrecord/testGraphData/testdata";
|
std::string path = "data/mindrecord/testGraphData/testdata";
|
||||||
Graph graph(path, 1);
|
GraphDataImpl graph(path, 1);
|
||||||
Status s = graph.Init();
|
Status s = graph.Init();
|
||||||
EXPECT_TRUE(s.IsOk());
|
EXPECT_TRUE(s.IsOk());
|
||||||
|
|
||||||
|
@ -237,7 +217,7 @@ TEST_F(MindDataTestGNNGraph, TestGetNegSampledNeighbors) {
|
||||||
|
|
||||||
TEST_F(MindDataTestGNNGraph, TestRandomWalk) {
|
TEST_F(MindDataTestGNNGraph, TestRandomWalk) {
|
||||||
std::string path = "data/mindrecord/testGraphData/sns";
|
std::string path = "data/mindrecord/testGraphData/sns";
|
||||||
Graph graph(path, 1);
|
GraphDataImpl graph(path, 1);
|
||||||
Status s = graph.Init();
|
Status s = graph.Init();
|
||||||
EXPECT_TRUE(s.IsOk());
|
EXPECT_TRUE(s.IsOk());
|
||||||
|
|
||||||
|
@ -263,7 +243,7 @@ TEST_F(MindDataTestGNNGraph, TestRandomWalk) {
|
||||||
|
|
||||||
TEST_F(MindDataTestGNNGraph, TestRandomWalkDefaults) {
|
TEST_F(MindDataTestGNNGraph, TestRandomWalkDefaults) {
|
||||||
std::string path = "data/mindrecord/testGraphData/sns";
|
std::string path = "data/mindrecord/testGraphData/sns";
|
||||||
Graph graph(path, 1);
|
GraphDataImpl graph(path, 1);
|
||||||
Status s = graph.Init();
|
Status s = graph.Init();
|
||||||
EXPECT_TRUE(s.IsOk());
|
EXPECT_TRUE(s.IsOk());
|
||||||
|
|
||||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,125 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
from multiprocessing import Process
|
||||||
|
import numpy as np
|
||||||
|
import mindspore.dataset as ds
|
||||||
|
from mindspore import log as logger
|
||||||
|
|
||||||
|
DATASET_FILE = "../data/mindrecord/testGraphData/testdata"
|
||||||
|
|
||||||
|
|
||||||
|
def graphdata_startserver():
|
||||||
|
"""
|
||||||
|
start graphdata server
|
||||||
|
"""
|
||||||
|
logger.info('test start server.\n')
|
||||||
|
ds.GraphData(DATASET_FILE, 1, 'server')
|
||||||
|
|
||||||
|
|
||||||
|
class RandomBatchedSampler(ds.Sampler):
|
||||||
|
# RandomBatchedSampler generate random sequence without replacement in a batched manner
|
||||||
|
def __init__(self, index_range, num_edges_per_sample):
|
||||||
|
super().__init__()
|
||||||
|
self.index_range = index_range
|
||||||
|
self.num_edges_per_sample = num_edges_per_sample
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
indices = [i+1 for i in range(self.index_range)]
|
||||||
|
# Reset random seed here if necessary
|
||||||
|
# random.seed(0)
|
||||||
|
random.shuffle(indices)
|
||||||
|
for i in range(0, self.index_range, self.num_edges_per_sample):
|
||||||
|
# Drop reminder
|
||||||
|
if i + self.num_edges_per_sample <= self.index_range:
|
||||||
|
yield indices[i: i + self.num_edges_per_sample]
|
||||||
|
|
||||||
|
|
||||||
|
class GNNGraphDataset():
|
||||||
|
def __init__(self, g, batch_num):
|
||||||
|
self.g = g
|
||||||
|
self.batch_num = batch_num
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
# Total sample size of GNN dataset
|
||||||
|
# In this case, the size should be total_num_edges/num_edges_per_sample
|
||||||
|
return self.g.graph_info()['edge_num'][0] // self.batch_num
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
# index will be a list of indices yielded from RandomBatchedSampler
|
||||||
|
# Fetch edges/nodes/samples/features based on indices
|
||||||
|
nodes = self.g.get_nodes_from_edges(index.astype(np.int32))
|
||||||
|
nodes = nodes[:, 0]
|
||||||
|
neg_nodes = self.g.get_neg_sampled_neighbors(
|
||||||
|
node_list=nodes, neg_neighbor_num=3, neg_neighbor_type=1)
|
||||||
|
nodes_neighbors = self.g.get_sampled_neighbors(node_list=nodes, neighbor_nums=[
|
||||||
|
2, 2], neighbor_types=[2, 1])
|
||||||
|
neg_nodes_neighbors = self.g.get_sampled_neighbors(
|
||||||
|
node_list=neg_nodes[:, 1:].reshape(-1), neighbor_nums=[2, 2], neighbor_types=[2, 2])
|
||||||
|
nodes_neighbors_features = self.g.get_node_feature(
|
||||||
|
node_list=nodes_neighbors, feature_types=[2, 3])
|
||||||
|
neg_neighbors_features = self.g.get_node_feature(
|
||||||
|
node_list=neg_nodes_neighbors, feature_types=[2, 3])
|
||||||
|
return nodes_neighbors, neg_nodes_neighbors, nodes_neighbors_features[0], neg_neighbors_features[1]
|
||||||
|
|
||||||
|
|
||||||
|
def test_graphdata_distributed():
|
||||||
|
"""
|
||||||
|
Test distributed
|
||||||
|
"""
|
||||||
|
logger.info('test distributed.\n')
|
||||||
|
|
||||||
|
p1 = Process(target=graphdata_startserver)
|
||||||
|
p1.start()
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
g = ds.GraphData(DATASET_FILE, 1, 'client')
|
||||||
|
nodes = g.get_all_nodes(1)
|
||||||
|
assert nodes.tolist() == [101, 102, 103, 104, 105, 106, 107, 108, 109, 110]
|
||||||
|
row_tensor = g.get_node_feature(nodes.tolist(), [1, 2, 3])
|
||||||
|
assert row_tensor[0].tolist() == [[0, 1, 0, 0, 0], [1, 0, 0, 0, 1], [0, 0, 1, 1, 0], [0, 0, 0, 0, 0],
|
||||||
|
[1, 1, 0, 1, 0], [0, 0, 0, 0, 1], [0, 1, 0, 0, 0], [0, 0, 0, 1, 1],
|
||||||
|
[0, 1, 1, 0, 0], [0, 1, 0, 1, 0]]
|
||||||
|
assert row_tensor[2].tolist() == [1, 2, 3, 1, 4, 3, 5, 3, 5, 4]
|
||||||
|
|
||||||
|
edges = g.get_all_edges(0)
|
||||||
|
assert edges.tolist() == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
|
||||||
|
21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]
|
||||||
|
features = g.get_edge_feature(edges, [1, 2])
|
||||||
|
assert features[0].tolist() == [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0,
|
||||||
|
0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0]
|
||||||
|
|
||||||
|
batch_num = 2
|
||||||
|
edge_num = g.graph_info()['edge_num'][0]
|
||||||
|
out_column_names = ["neighbors", "neg_neighbors", "neighbors_features", "neg_neighbors_features"]
|
||||||
|
dataset = ds.GeneratorDataset(source=GNNGraphDataset(g, batch_num), column_names=out_column_names,
|
||||||
|
sampler=RandomBatchedSampler(edge_num, batch_num), num_parallel_workers=4,
|
||||||
|
python_multiprocessing=False)
|
||||||
|
dataset = dataset.repeat(2)
|
||||||
|
itr = dataset.create_dict_iterator()
|
||||||
|
i = 0
|
||||||
|
for data in itr:
|
||||||
|
assert data['neighbors'].shape == (2, 7)
|
||||||
|
assert data['neg_neighbors'].shape == (6, 7)
|
||||||
|
assert data['neighbors_features'].shape == (2, 7)
|
||||||
|
assert data['neg_neighbors_features'].shape == (6, 7)
|
||||||
|
i += 1
|
||||||
|
assert i == 40
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_graphdata_distributed()
|
Loading…
Reference in New Issue