!16509 GNN dataset new feature for invalid edge or node id case

From: @lizhenglong1992
Reviewed-by: @liucunwei,@heleiwang
Signed-off-by: @liucunwei
This commit is contained in:
mindspore-ci-bot 2021-05-19 09:20:24 +08:00 committed by Gitee
commit 1cb2d31642
6 changed files with 56 additions and 24 deletions

View File

@ -141,10 +141,10 @@ Status GraphDataImpl::GetEdgesFromNodes(const std::vector<std::pair<NodeIdType,
std::shared_ptr<Node> src_node;
RETURN_IF_NOT_OK(GetNodeByNodeId(node_id.first, &src_node));
EdgeIdType *edge_id = nullptr;
EdgeIdType edge_id;
src_node->GetEdgeByAdjNodeId(node_id.second, &edge_id);
std::vector<EdgeIdType> connection_edge = {*edge_id};
std::vector<EdgeIdType> connection_edge = {edge_id};
edge_list.emplace_back(std::move(connection_edge));
}
@ -365,8 +365,8 @@ Status GraphDataImpl::GetNodeFeature(const std::shared_ptr<Tensor> &nodes,
feature = default_feature;
} else {
std::shared_ptr<Node> node;
RETURN_IF_NOT_OK(GetNodeByNodeId(*node_itr, &node));
if (!node->GetFeatures(f_type, &feature).IsOk()) {
if (!GetNodeByNodeId(*node_itr, &node).IsOk() || !node->GetFeatures(f_type, &feature).IsOk()) {
feature = default_feature;
}
}
@ -449,9 +449,9 @@ Status GraphDataImpl::GetEdgeFeature(const std::shared_ptr<Tensor> &edges,
dsize_t index = 0;
for (auto edge_itr = edges->begin<EdgeIdType>(); edge_itr != edges->end<EdgeIdType>(); ++edge_itr) {
std::shared_ptr<Edge> edge;
RETURN_IF_NOT_OK(GetEdgeByEdgeId(*edge_itr, &edge));
std::shared_ptr<Feature> feature;
if (!edge->GetFeatures(f_type, &feature).IsOk()) {
if (!GetEdgeByEdgeId(*edge_itr, &edge).IsOk() || !edge->GetFeatures(f_type, &feature).IsOk()) {
feature = default_feature;
}
RETURN_IF_NOT_OK(fea_tensor->InsertTensor({index}, feature->Value()));

View File

@ -138,13 +138,13 @@ Status LocalNode::AddAdjacent(const std::shared_ptr<Node> &node, const std::shar
return Status::OK();
}
Status LocalNode::GetEdgeByAdjNodeId(const NodeIdType &adj_node_id, EdgeIdType **out_edge_id) {
Status LocalNode::GetEdgeByAdjNodeId(const NodeIdType &adj_node_id, EdgeIdType *out_edge_id) {
auto itr = adjacent_nodes_.find(adj_node_id);
if (itr != adjacent_nodes_.end()) {
(*out_edge_id) = &(itr->second);
(*out_edge_id) = itr->second;
} else {
(*out_edge_id) = new EdgeIdType(-1);
(*out_edge_id) = -1;
MS_LOG(WARNING) << "Number " << adj_node_id << " node is not adjacent to number " << this->id() << " node.";
}

View File

@ -75,7 +75,7 @@ class LocalNode : public Node {
// @param NodeIdType - The id of adjacent node to be processed
// @param EdgeIdType * - Output: the id of the edge connecting to the adjacent node (-1 if the node is not adjacent)
// @return Status - The status code that indicate the result of function execution
Status GetEdgeByAdjNodeId(const NodeIdType &adj_node_id, EdgeIdType **out_edge_id) override;
Status GetEdgeByAdjNodeId(const NodeIdType &adj_node_id, EdgeIdType *out_edge_id) override;
// Update feature of node
// @param std::shared_ptr<Feature> feature -

View File

@ -91,7 +91,7 @@ class Node {
// @param NodeIdType - The id of adjacent node to be processed
// @param EdgeIdType * - Output: the id of the edge connecting to the adjacent node
// @return Status - The status code that indicate the result of function execution
virtual Status GetEdgeByAdjNodeId(const NodeIdType &adj_node_id, EdgeIdType **out_edge_id) = 0;
virtual Status GetEdgeByAdjNodeId(const NodeIdType &adj_node_id, EdgeIdType *out_edge_id) = 0;
// Update feature of node
// @param std::shared_ptr<Feature> feature -

View File

@ -1,17 +1,17 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#Licensed under the Apache License, Version 2.0(the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#http: // www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
#== == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == ==
"""
graphdata.py supports loading graph dataset for GNN network training,
and provides operations related to graph data.
@ -175,7 +175,7 @@ class GraphData:
numpy.ndarray, array of edge IDs.
Examples:
>>> edges = graph_dataset.get_edges_from_nodes([(1, 3), (5, 2)])
>>> edges = graph_dataset.get_edges_from_nodes(node_list=[(101, 201), (103, 207)])
Raises:
TypeError: If `node_list` is not list or ndarray.

View File

@ -241,6 +241,36 @@ def test_graphdata_getedgefeature():
assert features[1].shape == (40,)
def test_graphdata_getedgefeature_invalidcase():
    """
    Check get_edge_feature when the edge list contains an invalid edge id: the entry at the
    position of the invalid id is expected to be 0 in every requested feature.
    """
    logger.info('test get_edge_feature.\n')
    graph = ds.GraphData(DATASET_FILE)
    edge_ids = graph.get_all_edges(0)
    # Replace one id with -1 (the invalid edge id) and remember which slot was overwritten.
    invalid_pos = -6
    edge_ids[invalid_pos] = -1
    edge_features = graph.get_edge_feature(edge_ids, [1, 2])
    # Both requested features must keep their full length despite the invalid id.
    for feature_idx in (0, 1):
        assert edge_features[feature_idx].shape == (40,)
    # The slot of the invalid id must hold the default value 0.
    for feature_idx in (0, 1):
        assert edge_features[feature_idx][invalid_pos] == 0
def test_graphdata_getnodefeature_invalidcase():
    """
    Check get_node_feature when the node list contains an invalid node id: the entry at the
    position of the invalid id is expected to be 0 in every requested feature.
    """
    logger.info('test get_node_feature.\n')
    graph = ds.GraphData(DATASET_FILE)
    node_ids = graph.get_all_nodes(node_type=1)
    # Replace one id with -1 (the invalid node id) and remember which slot was overwritten.
    invalid_pos = 5
    node_ids[invalid_pos] = -1
    node_features = graph.get_node_feature(node_list=node_ids, feature_types=[2, 3])
    # Both requested features must keep their full length despite the invalid id.
    for feature_idx in (0, 1):
        assert node_features[feature_idx].shape == (10,)
    # The slot of the invalid id must hold the default value 0.
    for feature_idx in (0, 1):
        assert node_features[feature_idx][invalid_pos] == 0
def test_graphdata_getedgesfromnodes():
"""
Test get edges from nodes
@ -249,7 +279,7 @@ def test_graphdata_getedgesfromnodes():
g = ds.GraphData(DATASET_FILE)
nodes_pair_list = [(101, 201), (103, 207), (204, 105), (108, 208), (110, 210), (210, 110)]
edges = g.get_edges_from_nodes(nodes_pair_list)
edges = g.get_edges_from_nodes(node_list=nodes_pair_list)
assert edges.tolist() == [1, 9, 31, 17, 20, 40]
@ -264,3 +294,5 @@ if __name__ == '__main__':
test_graphdata_randomwalk()
test_graphdata_getedgefeature()
test_graphdata_getedgesfromnodes()
test_graphdata_getnodefeature_invalidcase()
test_graphdata_getedgefeature_invalidcase()