Implement new gnn feature

This commit is contained in:
Zhenglong Li 2021-06-04 14:16:08 +08:00
parent a4e2ab0487
commit 2c7d1bf603
19 changed files with 227 additions and 30 deletions

View File

@ -63,9 +63,10 @@ PYBIND_REGISTER(
return out;
})
.def("get_all_neighbors",
[](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeType neighbor_type) {
[](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeType neighbor_type,
OutputFormat format) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out));
THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, format, &out));
return out;
})
.def("get_sampled_neighbors",
@ -129,5 +130,13 @@ PYBIND_REGISTER(SamplingStrategy, 0, ([](const py::module *m) {
.export_values();
}));
PYBIND_REGISTER(OutputFormat, 0, ([](const py::module *m) {
(void)py::enum_<OutputFormat>(*m, "OutputFormat", py::arithmetic())
.value("DE_FORMAT_NORMAL", OutputFormat::kNormal)
.value("DE_FORMAT_COO", OutputFormat::kCoo)
.value("DE_FORMAT_CSR", OutputFormat::kCsr)
.export_values();
}));
} // namespace dataset
} // namespace mindspore

View File

@ -79,6 +79,7 @@ message GnnGraphDataRequestPb {
GnnRandomWalkPb random_walk = 6;
int32 strategy = 7;
repeated IdPairPb node_pair = 8;
int32 format = 9; // output format for GET_ALL_NEIGHBORS function
}
message GnnGraphDataResponsePb {

View File

@ -72,12 +72,13 @@ class GraphData {
// All neighbors of the acquisition node.
// @param std::vector<NodeType> node_list - List of nodes
// @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported
// @param OutputFormat format - The storage format for output, normal, COO or CSR are valid
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is
// different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors
// is not enough, fill in tensor as -1.
// @return Status The status code returned
virtual Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
std::shared_ptr<Tensor> *out) = 0;
const OutputFormat &format, std::shared_ptr<Tensor> *out) = 0;
// Get sampled neighbors.
// @param std::vector<NodeType> node_list - List of nodes

View File

@ -140,7 +140,7 @@ Status GraphDataClient::GetEdgesFromNodes(const std::vector<std::pair<NodeIdType
}
Status GraphDataClient::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
std::shared_ptr<Tensor> *out) {
const OutputFormat &format, std::shared_ptr<Tensor> *out) {
#if !defined(_WIN32) && !defined(_WIN64)
GnnGraphDataRequestPb request;
GnnGraphDataResponsePb response;
@ -149,6 +149,7 @@ Status GraphDataClient::GetAllNeighbors(const std::vector<NodeIdType> &node_list
request.add_id(static_cast<google::protobuf::int32>(node_id));
}
request.add_type(static_cast<google::protobuf::int32>(neighbor_type));
request.set_format(static_cast<google::protobuf::int32>(format));
RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out));
#endif
return Status::OK();

View File

@ -82,11 +82,12 @@ class GraphDataClient : public GraphData {
// All neighbors of the acquisition node.
// @param std::vector<NodeType> node_list - List of nodes
// @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported
// @param OutputFormat format - The storage format for output, normal, COO or CSR are valid
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is
// different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors
// is not enough, fill in tensor as -1.
// @return Status The status code returned
Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, const OutputFormat &format,
std::shared_ptr<Tensor> *out) override;
// Get sampled neighbors.

View File

@ -153,23 +153,68 @@ Status GraphDataImpl::GetEdgesFromNodes(const std::vector<std::pair<NodeIdType,
}
Status GraphDataImpl::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
std::shared_ptr<Tensor> *out) {
const OutputFormat &format, std::shared_ptr<Tensor> *out) {
CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty.");
RETURN_IF_NOT_OK(CheckNeighborType(neighbor_type));
std::vector<std::vector<NodeIdType>> neighbors;
size_t max_neighbor_num = 0;
size_t max_neighbor_num = 0; // Special parameter for normal format
size_t total_edge_num = 0; // Special parameter for coo and csr format
std::vector<NodeIdType> offset_table(node_list.size(), 0); // Special parameter for csr format
// Collect information of adjacent table
neighbors.resize(node_list.size());
for (size_t i = 0; i < node_list.size(); ++i) {
std::shared_ptr<Node> node;
RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[i], &node));
if (format == OutputFormat::kNormal) {
RETURN_IF_NOT_OK(node->GetAllNeighbors(neighbor_type, &neighbors[i]));
max_neighbor_num = max_neighbor_num > neighbors[i].size() ? max_neighbor_num : neighbors[i].size();
} else if (format == OutputFormat::kCoo) {
RETURN_IF_NOT_OK(node->GetAllNeighbors(neighbor_type, &neighbors[i], true));
total_edge_num += neighbors[i].size();
} else {
RETURN_IF_NOT_OK(node->GetAllNeighbors(neighbor_type, &neighbors[i], true));
total_edge_num += neighbors[i].size();
if (i < node_list.size() - 1) {
offset_table[i + 1] = total_edge_num;
}
}
}
// By applying those information we obtained above, deal with the output with corresponding to
// output format
if (format == OutputFormat::kNormal) {
RETURN_IF_NOT_OK(ComplementVector<NodeIdType>(&neighbors, max_neighbor_num, kDefaultNodeId));
RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(neighbors, DataType(DataType::DE_INT32), out));
} else if (format == OutputFormat::kCoo) {
std::vector<std::vector<NodeIdType>> coo_result;
coo_result.resize(total_edge_num);
size_t k = 0;
for (size_t i = 0; i < neighbors.size(); ++i) {
NodeIdType src = node_list[i];
for (auto &dst : neighbors[i]) {
coo_result[k] = {src, dst};
k++;
}
}
RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(coo_result, DataType(DataType::DE_INT32), out));
} else {
std::vector<std::vector<NodeIdType>> csr_result;
csr_result.resize(node_list.size() + total_edge_num);
for (size_t i = 0; i < offset_table.size(); ++i) {
csr_result[i] = {offset_table[i]};
}
size_t edge_index = 0;
for (auto &neighbor : neighbors) {
for (auto &dst : neighbor) {
csr_result[node_list.size() + edge_index] = {dst};
edge_index++;
}
}
RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(csr_result, DataType(DataType::DE_INT32), out));
}
return Status::OK();
}

View File

@ -76,11 +76,12 @@ class GraphDataImpl : public GraphData {
// All neighbors of the acquisition node.
// @param std::vector<NodeType> node_list - List of nodes
// @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported
// @param OutputFormat format - The storage format for output, normal, COO or CSR are valid
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is
// different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors
// is not enough, fill in tensor as -1.
// @return Status The status code returned
Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, const OutputFormat &format,
std::shared_ptr<Tensor> *out) override;
// Get sampled neighbors.

View File

@ -220,8 +220,10 @@ Status GraphDataServiceImpl::GetAllNeighbors(const GnnGraphDataRequestPb *reques
node_list.resize(request->id().size());
std::transform(request->id().begin(), request->id().end(), node_list.begin(),
[](const google::protobuf::int32 id) { return static_cast<NodeIdType>(id); });
OutputFormat format = static_cast<OutputFormat>(request->format());
std::shared_ptr<Tensor> tensor;
RETURN_IF_NOT_OK(graph_data_impl_->GetAllNeighbors(node_list, static_cast<NodeType>(request->type()[0]), &tensor));
RETURN_IF_NOT_OK(
graph_data_impl_->GetAllNeighbors(node_list, static_cast<NodeType>(request->type()[0]), format, &tensor));
TensorPb *result = response->add_result_data();
RETURN_IF_NOT_OK(TensorToPb(tensor, result));
return Status::OK();

View File

@ -131,7 +131,14 @@ enum class SamplingStrategy {
kEdgeWeight = 1 ///< Sampling with edge weight as probability.
};
// convenience functions for 32bit int bitmask.
/// \brief Possible values for output format in get all neighbors function of gnn dataset
enum class OutputFormat {
kNormal = 0, ///< Normal format>
kCoo = 1, ///< COO format>
kCsr = 2 ///< CSR format>
};
// convenience functions for 32bit int bitmask
inline bool BitTest(uint32_t bits, uint32_t bitMask) { return (bits & bitMask) == bitMask; }
inline void BitSet(uint32_t *bits, uint32_t bitMask) { *bits |= bitMask; }

View File

@ -104,7 +104,7 @@ Status DvppDecodeJpegOp::Compute(const std::shared_ptr<Tensor> &input, std::shar
uint32_t decoded_heightStride = DecodeOut->heightStride;
uint32_t decoded_width = DecodeOut->width;
uint32_t decoded_widthStride = DecodeOut->widthStride;
// std::cout << "Decoded size: " << decoded_width << ", " << decoded_height << std::endl;
const TensorShape dvpp_shape({dvpp_length, 1, 1});
const DataType dvpp_data_type(DataType::DE_UINT8);
mindspore::dataset::Tensor::CreateFromMemory(dvpp_shape, dvpp_data_type, ret_ptr, output);

View File

@ -39,7 +39,7 @@ Status DvppDecodeResizeCropJpegOp::Compute(const std::shared_ptr<DeviceTensor> &
RETURN_STATUS_UNEXPECTED(error);
}
std::shared_ptr<DvppDataInfo> CropOut(processor_->Get_Croped_DeviceData());
// std::cout << "Decoded size: " << decoded_width << ", " << decoded_height << std::endl;
const TensorShape dvpp_shape({1, 1, 1});
const DataType dvpp_data_type(DataType::DE_UINT8);
mindspore::dataset::DeviceTensor::CreateEmpty(dvpp_shape, dvpp_data_type, output);

View File

@ -1051,7 +1051,6 @@ APP_ERROR DvppCommon::CombineJpegdProcess(const RawData &imageInfo, acldvppPixel
inputImage_ = std::make_shared<DvppDataInfo>();
inputImage_->format = format;
APP_ERROR ret =
// GetJpegImageInfo(imageInfo.data.get(), imageInfo.lenOfByte, inputImage_->width, inputImage_->height, components);
GetJpegImageInfo(imageInfo.data, imageInfo.lenOfByte, inputImage_->width, inputImage_->height, components);
if (ret != APP_ERR_OK) {
MS_LOG(ERROR) << "Failed to get input image info, ret = " << ret << ".";
@ -1153,7 +1152,6 @@ APP_ERROR DvppCommon::CombinePngdProcess(const RawData &imageInfo, acldvppPixelF
inputImage_ = std::make_shared<DvppDataInfo>();
inputImage_->format = format;
APP_ERROR ret =
// GetJpegImageInfo(imageInfo.data.get(), imageInfo.lenOfByte, inputImage_->width, inputImage_->height, components);
GetPngImageInfo(imageInfo.data, imageInfo.lenOfByte, inputImage_->width, inputImage_->height, components);
if (ret != APP_ERR_OK) {
MS_LOG(ERROR) << "Failed to get input image info, ret = " << ret << ".";
@ -1268,8 +1266,6 @@ APP_ERROR DvppCommon::TransferImageH2D(const RawData &imageInfo, const std::shar
}
// Copy the image data from host to device
// ret = aclrtMemcpyAsync(inDevBuff, imageInfo.lenOfByte, imageInfo.data.get(), imageInfo.lenOfByte,
// ACL_MEMCPY_HOST_TO_DEVICE, dvppStream_);
ret = aclrtMemcpyAsync(inDevBuff, imageInfo.lenOfByte, imageInfo.data, imageInfo.lenOfByte, ACL_MEMCPY_HOST_TO_DEVICE,
dvppStream_);
if (ret != APP_ERR_OK) {
@ -1319,7 +1315,7 @@ APP_ERROR DvppCommon::SinkImageH2D(const RawData &imageInfo, acldvppPixelFormat
// Get the buffer size(On device) of decode output according to the input data and output format
uint32_t outBufferSize;
// ret = GetJpegDecodeDataSize(imageInfo.data.get(), imageInfo.lenOfByte, format, outBuffSize);
ret = GetJpegDecodeDataSize(imageInfo.data, imageInfo.lenOfByte, format, outBufferSize);
if (ret != APP_ERR_OK) {
MS_LOG(ERROR) << "Failed to get size of decode output buffer, ret = " << ret << ".";
@ -1377,7 +1373,7 @@ APP_ERROR DvppCommon::SinkImageH2D(const RawData &imageInfo) {
// Get the buffer size of decode output according to the input data and output format
uint32_t outBuffSize;
// ret = GetJpegDecodeDataSize(imageInfo.data.get(), imageInfo.lenOfByte, format, outBuffSize);
ret = GetPngDecodeDataSize(imageInfo.data, imageInfo.lenOfByte, format, outBuffSize);
if (ret != APP_ERR_OK) {
MS_LOG(ERROR) << "Failed to get size of decode output buffer, ret = " << ret << ".";

View File

@ -172,7 +172,7 @@ APP_ERROR MDAclProcess::H2D_Sink(const std::shared_ptr<mindspore::dataset::Tenso
RawData imageinfo;
uint32_t filesize = input->SizeInBytes();
// MS_LOG(INFO) << "Filesize on host is: " << filesize;
imageinfo.lenOfByte = filesize;
unsigned char *buffer = const_cast<unsigned char *>(input->GetBuffer());
imageinfo.data = static_cast<void *>(buffer);
@ -188,7 +188,7 @@ APP_ERROR MDAclProcess::H2D_Sink(const std::shared_ptr<mindspore::dataset::Tenso
return ret;
}
auto deviceInputData = dvppCommon_->GetInputImage();
// std::cout << "[DEBUG]Sink data sunccessfully, Filesize on device is: " << deviceInputData->dataSize << std::endl;
const mindspore::dataset::DataType dvpp_data_type(mindspore::dataset::DataType::DE_UINT8);
const mindspore::dataset::TensorShape dvpp_shape({1, 1, 1});
mindspore::dataset::DeviceTensor::CreateEmpty(dvpp_shape, dvpp_data_type, &device_input);

View File

@ -25,7 +25,7 @@ operations for users to preprocess data: shuffle, batch, repeat, map, and zip.
from ..core import config
from .cache_client import DatasetCache
from .datasets import *
from .graphdata import GraphData, SamplingStrategy
from .graphdata import GraphData, SamplingStrategy, OutputFormat
from .iterators import *
from .samplers import *
from .serializer_deserializer import compare, deserialize, serialize, show

View File

@ -24,6 +24,7 @@ from mindspore._c_dataengine import GraphDataClient
from mindspore._c_dataengine import GraphDataServer
from mindspore._c_dataengine import Tensor
from mindspore._c_dataengine import SamplingStrategy as Sampling
from mindspore._c_dataengine import OutputFormat as Format
from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \
check_gnn_get_nodes_from_edges, check_gnn_get_edges_from_nodes, check_gnn_get_all_neighbors, \
@ -42,6 +43,19 @@ DE_C_INTER_SAMPLING_STRATEGY = {
}
class OutputFormat(IntEnum):
NORMAL = 0
COO = 1
CSR = 2
DE_C_INTER_OUTPUT_FORMAT = {
OutputFormat.NORMAL: Format.DE_FORMAT_NORMAL,
OutputFormat.COO: Format.DE_FORMAT_COO,
OutputFormat.CSR: Format.DE_FORMAT_CSR,
}
class GraphData:
"""
Reads the graph dataset used for GNN training from the shared file and database.
@ -185,20 +199,65 @@ class GraphData:
return self._graph_data.get_edges_from_nodes(node_list).as_array()
@check_gnn_get_all_neighbors
def get_all_neighbors(self, node_list, neighbor_type):
def get_all_neighbors(self, node_list, neighbor_type, output_format=OutputFormat.NORMAL):
"""
Get `neighbor_type` neighbors of the nodes in `node_list`.
Args:
node_list (Union[list, numpy.ndarray]): The given list of nodes.
neighbor_type (int): Specify the type of neighbor.
output_format (OutputFormat, optional): Output storage format (default=OutputFormat.NORMAL)
It can be any of [OutputFormat.NORMAL, OutputFormat.COO, OutputFormat.CSR].
Returns:
For NORMAL format or COO format
numpy.ndarray, array of neighbors.
If CSR format is specified, two numpy.ndarrays will return.
The first is offset table, the second is neighbors
Examples:
We try to use the following example to illustrate the definition of these formats. 1 represents connected
between two nodes, and 0 represents not connected.
Raw Data:
0 1 2 3
0 0 1 0 0
1 0 0 1 0
2 1 0 0 1
3 1 0 0 0
Normal format
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
>>> neighbors = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type=2)
NORMAL:
dst_0 dst_1
0 1 -1
1 2 -1
2 0 3
3 1 -1
COO format
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
>>> neighbors_coo = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type=2,
output_format=OutputFormat.COO)
COO:
src dst
0 1
1 2
2 0
2 3
3 1
CSR format
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
>>> offset_table, neighbors_csr = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type=2,
output_format=OutputFormat.CSR)
CSR:
offset table: dst table:
0 1
1 2
2 0
4 3
1
Raises:
TypeError: If `node_list` is not list or ndarray.
@ -206,7 +265,13 @@ class GraphData:
"""
if self._working_mode == 'server':
raise Exception("This method is not supported when working mode is server.")
return self._graph_data.get_all_neighbors(node_list, neighbor_type).as_array()
result_list = self._graph_data.get_all_neighbors(node_list, neighbor_type,
DE_C_INTER_OUTPUT_FORMAT[output_format]).as_array()
if output_format == OutputFormat.CSR:
offset_table = result_list[:len(node_list)]
neighbor_table = result_list[len(node_list):]
return offset_table, neighbor_table
return result_list
@check_gnn_get_sampled_neighbors
def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types, strategy=SamplingStrategy.RANDOM):

View File

@ -1108,7 +1108,7 @@ def check_gnn_get_all_neighbors(method):
@wraps(method)
def new_method(self, *args, **kwargs):
[node_list, neighbour_type], _ = parse_user_args(method, *args, **kwargs)
[node_list, neighbour_type, _], _ = parse_user_args(method, *args, **kwargs)
check_gnn_list_or_ndarray(node_list, 'node_list')
type_check(neighbour_type, (int,), "neighbour_type")

View File

@ -132,7 +132,7 @@ TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) {
}
}
std::shared_ptr<Tensor> neighbors;
s = graph.GetAllNeighbors(node_list, meta_info.node_type[1], &neighbors);
s = graph.GetAllNeighbors(node_list, meta_info.node_type[1], OutputFormat::kNormal, &neighbors);
EXPECT_TRUE(s.IsOk());
EXPECT_TRUE(neighbors->shape().ToString() == "<10,6>");
TensorRow features;
@ -151,6 +151,47 @@ TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) {
EXPECT_TRUE(features[2]->ToString() == "Tensor (shape: <10>, Type: int32)\n[1,2,3,1,4,3,5,3,5,4]");
}
TEST_F(MindDataTestGNNGraph, TestGetAllNeighborsSpecialFormat) {
std::string path = "data/mindrecord/testGraphData/testdata";
GraphDataImpl graph(path, 1);
Status s = graph.Init();
EXPECT_TRUE(s.IsOk());
MetaInfo meta_info;
s = graph.GetMetaInfo(&meta_info);
EXPECT_TRUE(s.IsOk());
EXPECT_TRUE(meta_info.node_type.size() == 2);
std::shared_ptr<Tensor> nodes;
s = graph.GetAllNodes(meta_info.node_type[0], &nodes);
EXPECT_TRUE(s.IsOk());
std::vector<NodeIdType> node_list;
for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) {
node_list.push_back(*itr);
if (node_list.size() >= 10) {
break;
}
}
// Check COO format
std::shared_ptr<Tensor> neighbors_coo;
s = graph.GetAllNeighbors(node_list, meta_info.node_type[1], OutputFormat::kCoo, &neighbors_coo);
EXPECT_TRUE(s.IsOk());
EXPECT_TRUE(neighbors_coo->shape().ToString() == "<20,2>");
EXPECT_TRUE(neighbors_coo->ToString() ==
"Tensor (shape: <20,2>, Type: int32)\n"
"[[101,201],[101,205],[101,206],[102,201],[102,202],[103,203],[103,205],[103,206],[103,207],[103,208],"
"[105,204],[106,202],[106,203],[107,201],[107,203],[107,207],[108,208],[109,210],[110,201],[110,210]]");
// Check CSR format
std::shared_ptr<Tensor> neighbors_csr;
s = graph.GetAllNeighbors(node_list, meta_info.node_type[1], OutputFormat::kCsr, &neighbors_csr);
EXPECT_TRUE(s.IsOk());
EXPECT_TRUE(neighbors_csr->shape().ToString() == "<30>");
EXPECT_TRUE(
neighbors_csr->ToString() ==
"Tensor (shape: <30>, Type: int32)\n"
"[0,3,5,10,10,11,13,16,17,18,201,205,206,201,202,203,205,206,207,208,204,202,203,201,203,207,208,210,201,210]");
}
TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) {
std::string path = "data/mindrecord/testGraphData/testdata";
GraphDataImpl graph(path, 1);

View File

@ -18,6 +18,7 @@ import numpy as np
import mindspore.dataset as ds
from mindspore import log as logger
from mindspore.dataset.engine import SamplingStrategy
from mindspore.dataset.engine import OutputFormat
DATASET_FILE = "../data/mindrecord/testGraphData/testdata"
SOCIAL_DATA_FILE = "../data/mindrecord/testGraphData/sns"
@ -37,6 +38,23 @@ def test_graphdata_getfullneighbor():
assert row_tensor[0].shape == (10, 6)
def test_graphdata_getallneighbors_special_format():
"""
Test get all neighbors with special format
"""
logger.info('test get all neighbors with special format.\n')
g = ds.GraphData(DATASET_FILE, 2)
nodes = g.get_all_nodes(1)
assert len(nodes) == 10
neighbor_coo = g.get_all_neighbors(nodes, 2, OutputFormat.COO)
assert neighbor_coo.shape == (20, 2)
offset_table, neighbor_csr = g.get_all_neighbors(nodes, 2, OutputFormat.CSR)
assert offset_table.shape == (10,)
assert neighbor_csr.shape == (20,)
def test_graphdata_getnodefeature_input_check():
"""
Test get node feature input check

View File

@ -21,6 +21,7 @@ import numpy as np
import mindspore.dataset as ds
from mindspore import log as logger
from mindspore.dataset.engine import SamplingStrategy
from mindspore.dataset.engine import OutputFormat
DATASET_FILE = "../data/mindrecord/testGraphData/testdata"
@ -105,6 +106,14 @@ def test_graphdata_distributed():
[0, 1, 1, 0, 0], [0, 1, 0, 1, 0]]
assert row_tensor[2].tolist() == [1, 2, 3, 1, 4, 3, 5, 3, 5, 4]
neighbor_normal = g.get_all_neighbors(nodes, 2, OutputFormat.NORMAL)
assert neighbor_normal.shape == (10, 6)
neighbor_coo = g.get_all_neighbors(nodes, 2, OutputFormat.COO)
assert neighbor_coo.shape == (20, 2)
offset_table, neighbor_csr = g.get_all_neighbors(nodes, 2, OutputFormat.CSR)
assert offset_table.shape == (10,)
assert neighbor_csr.shape == (20,)
edges = g.get_all_edges(0)
assert edges.tolist() == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]