forked from mindspore-Ecosystem/mindspore
!16509 GNN dataset new feature for invalid edge or node id case
From: @lizhenglong1992 Reviewed-by: @liucunwei,@heleiwang Signed-off-by: @liucunwei
This commit is contained in:
commit
1cb2d31642
|
@ -141,10 +141,10 @@ Status GraphDataImpl::GetEdgesFromNodes(const std::vector<std::pair<NodeIdType,
|
|||
std::shared_ptr<Node> src_node;
|
||||
RETURN_IF_NOT_OK(GetNodeByNodeId(node_id.first, &src_node));
|
||||
|
||||
EdgeIdType *edge_id = nullptr;
|
||||
EdgeIdType edge_id;
|
||||
src_node->GetEdgeByAdjNodeId(node_id.second, &edge_id);
|
||||
|
||||
std::vector<EdgeIdType> connection_edge = {*edge_id};
|
||||
std::vector<EdgeIdType> connection_edge = {edge_id};
|
||||
edge_list.emplace_back(std::move(connection_edge));
|
||||
}
|
||||
|
||||
|
@ -365,8 +365,8 @@ Status GraphDataImpl::GetNodeFeature(const std::shared_ptr<Tensor> &nodes,
|
|||
feature = default_feature;
|
||||
} else {
|
||||
std::shared_ptr<Node> node;
|
||||
RETURN_IF_NOT_OK(GetNodeByNodeId(*node_itr, &node));
|
||||
if (!node->GetFeatures(f_type, &feature).IsOk()) {
|
||||
|
||||
if (!GetNodeByNodeId(*node_itr, &node).IsOk() || !node->GetFeatures(f_type, &feature).IsOk()) {
|
||||
feature = default_feature;
|
||||
}
|
||||
}
|
||||
|
@ -449,9 +449,9 @@ Status GraphDataImpl::GetEdgeFeature(const std::shared_ptr<Tensor> &edges,
|
|||
dsize_t index = 0;
|
||||
for (auto edge_itr = edges->begin<EdgeIdType>(); edge_itr != edges->end<EdgeIdType>(); ++edge_itr) {
|
||||
std::shared_ptr<Edge> edge;
|
||||
RETURN_IF_NOT_OK(GetEdgeByEdgeId(*edge_itr, &edge));
|
||||
std::shared_ptr<Feature> feature;
|
||||
if (!edge->GetFeatures(f_type, &feature).IsOk()) {
|
||||
|
||||
if (!GetEdgeByEdgeId(*edge_itr, &edge).IsOk() || !edge->GetFeatures(f_type, &feature).IsOk()) {
|
||||
feature = default_feature;
|
||||
}
|
||||
RETURN_IF_NOT_OK(fea_tensor->InsertTensor({index}, feature->Value()));
|
||||
|
|
|
@ -138,13 +138,13 @@ Status LocalNode::AddAdjacent(const std::shared_ptr<Node> &node, const std::shar
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status LocalNode::GetEdgeByAdjNodeId(const NodeIdType &adj_node_id, EdgeIdType **out_edge_id) {
|
||||
Status LocalNode::GetEdgeByAdjNodeId(const NodeIdType &adj_node_id, EdgeIdType *out_edge_id) {
|
||||
auto itr = adjacent_nodes_.find(adj_node_id);
|
||||
|
||||
if (itr != adjacent_nodes_.end()) {
|
||||
(*out_edge_id) = &(itr->second);
|
||||
(*out_edge_id) = itr->second;
|
||||
} else {
|
||||
(*out_edge_id) = new EdgeIdType(-1);
|
||||
(*out_edge_id) = -1;
|
||||
MS_LOG(WARNING) << "Number " << adj_node_id << " node is not adjacent to number " << this->id() << " node.";
|
||||
}
|
||||
|
||||
|
|
|
@ -75,7 +75,7 @@ class LocalNode : public Node {
|
|||
// @param NodeIdType - The id of adjacent node to be processed
|
||||
// @param std::shared_ptr<EdgeIdType> - The id of relative connecting edge
|
||||
// @return Status - The status code that indicate the result of function execution
|
||||
Status GetEdgeByAdjNodeId(const NodeIdType &adj_node_id, EdgeIdType **out_edge_id) override;
|
||||
Status GetEdgeByAdjNodeId(const NodeIdType &adj_node_id, EdgeIdType *out_edge_id) override;
|
||||
|
||||
// Update feature of node
|
||||
// @param std::shared_ptr<Feature> feature -
|
||||
|
|
|
@ -91,7 +91,7 @@ class Node {
|
|||
// @param NodeIdType - The id of adjacent node to be processed
|
||||
// @param std::shared_ptr<EdgeIdType> - The id of relative connecting edge
|
||||
// @return Status - The status code that indicate the result of function execution
|
||||
virtual Status GetEdgeByAdjNodeId(const NodeIdType &adj_node_id, EdgeIdType **out_edge_id) = 0;
|
||||
virtual Status GetEdgeByAdjNodeId(const NodeIdType &adj_node_id, EdgeIdType *out_edge_id) = 0;
|
||||
|
||||
// Update feature of node
|
||||
// @param std::shared_ptr<Feature> feature -
|
||||
|
|
|
@ -1,17 +1,17 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#Licensed under the Apache License, Version 2.0(the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#http: // www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
#== == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == == ==
|
||||
"""
|
||||
graphdata.py supports loading graph dataset for GNN network training,
|
||||
and provides operations related to graph data.
|
||||
|
@ -175,7 +175,7 @@ class GraphData:
|
|||
numpy.ndarray, array of edgs ID.
|
||||
|
||||
Examples:
|
||||
>>> edges = graph_dataset.get_edges_from_nodes([(1, 3), (5, 2)])
|
||||
>>> edges = graph_dataset.get_edges_from_nodes(node_list=[(101, 201), (103, 207)])
|
||||
|
||||
Raises:
|
||||
TypeError: If `edge_list` is not list or ndarray.
|
||||
|
|
|
@ -241,6 +241,36 @@ def test_graphdata_getedgefeature():
|
|||
assert features[1].shape == (40,)
|
||||
|
||||
|
||||
def test_graphdata_getedgefeature_invalidcase():
|
||||
"""
|
||||
Test get edge feature with invalid edge id, 0 should be returned for those invalid edge id in correct index
|
||||
"""
|
||||
logger.info('test get_edge_feature.\n')
|
||||
g = ds.GraphData(DATASET_FILE)
|
||||
edges = g.get_all_edges(0)
|
||||
edges[-6] = -1
|
||||
features = g.get_edge_feature(edges, [1, 2])
|
||||
assert features[0].shape == (40,)
|
||||
assert features[1].shape == (40,)
|
||||
assert features[0][-6] == 0
|
||||
assert features[1][-6] == 0.
|
||||
|
||||
|
||||
def test_graphdata_getnodefeature_invalidcase():
|
||||
"""
|
||||
Test get node feature with invalid node id, 0 should be returned for those invalid node id in correct index
|
||||
"""
|
||||
logger.info('test get_node_feature.\n')
|
||||
g = ds.GraphData(DATASET_FILE)
|
||||
nodes = g.get_all_nodes(node_type=1)
|
||||
nodes[5] = -1
|
||||
features = g.get_node_feature(node_list=nodes, feature_types=[2, 3])
|
||||
assert features[0].shape == (10,)
|
||||
assert features[1].shape == (10,)
|
||||
assert features[0][5] == 0.
|
||||
assert features[1][5] == 0
|
||||
|
||||
|
||||
def test_graphdata_getedgesfromnodes():
|
||||
"""
|
||||
Test get edges from nodes
|
||||
|
@ -249,7 +279,7 @@ def test_graphdata_getedgesfromnodes():
|
|||
g = ds.GraphData(DATASET_FILE)
|
||||
|
||||
nodes_pair_list = [(101, 201), (103, 207), (204, 105), (108, 208), (110, 210), (210, 110)]
|
||||
edges = g.get_edges_from_nodes(nodes_pair_list)
|
||||
edges = g.get_edges_from_nodes(node_list=nodes_pair_list)
|
||||
assert edges.tolist() == [1, 9, 31, 17, 20, 40]
|
||||
|
||||
|
||||
|
@ -264,3 +294,5 @@ if __name__ == '__main__':
|
|||
test_graphdata_randomwalk()
|
||||
test_graphdata_getedgefeature()
|
||||
test_graphdata_getedgesfromnodes()
|
||||
test_graphdata_getnodefeature_invalidcase()
|
||||
test_graphdata_getedgefeature_invalidcase()
|
||||
|
|
Loading…
Reference in New Issue