diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/gnn/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/gnn/bindings.cc index 8f77a0091f8..28ec326316a 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/gnn/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/engine/gnn/bindings.cc @@ -63,9 +63,10 @@ PYBIND_REGISTER( return out; }) .def("get_all_neighbors", - [](gnn::GraphData &g, std::vector node_list, gnn::NodeType neighbor_type) { + [](gnn::GraphData &g, std::vector node_list, gnn::NodeType neighbor_type, + OutputFormat format) { std::shared_ptr out; - THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out)); + THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, format, &out)); return out; }) .def("get_sampled_neighbors", @@ -129,5 +130,13 @@ PYBIND_REGISTER(SamplingStrategy, 0, ([](const py::module *m) { .export_values(); })); +PYBIND_REGISTER(OutputFormat, 0, ([](const py::module *m) { + (void)py::enum_(*m, "OutputFormat", py::arithmetic()) + .value("DE_FORMAT_NORMAL", OutputFormat::kNormal) + .value("DE_FORMAT_COO", OutputFormat::kCoo) + .value("DE_FORMAT_CSR", OutputFormat::kCsr) + .export_values(); + })); + } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/gnn_graph_data.proto b/mindspore/ccsrc/minddata/dataset/engine/gnn/gnn_graph_data.proto index 0c6a92d1c24..8701ef90ce6 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/gnn_graph_data.proto +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/gnn_graph_data.proto @@ -79,6 +79,7 @@ message GnnGraphDataRequestPb { GnnRandomWalkPb random_walk = 6; int32 strategy = 7; repeated IdPairPb node_pair = 8; + int32 format = 9; // output format for GET_ALL_NEIGHBORS function } message GnnGraphDataResponsePb { diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data.h 
b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data.h index 965b45a30bc..7722717f48f 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data.h @@ -72,12 +72,13 @@ class GraphData { // All neighbors of the acquisition node. // @param std::vector node_list - List of nodes // @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported + // @param OutputFormat format - The storage format for output, normal, COO or CSR are valid // @param std::shared_ptr *out - Returned neighbor's id. Because the number of neighbors at different nodes is // different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors // is not enough, fill in tensor as -1. // @return Status The status code returned virtual Status GetAllNeighbors(const std::vector &node_list, NodeType neighbor_type, - std::shared_ptr *out) = 0; + const OutputFormat &format, std::shared_ptr *out) = 0; // Get sampled neighbors. 
// @param std::vector node_list - List of nodes diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.cc index 4839bd910ff..0415c11c354 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.cc @@ -140,7 +140,7 @@ Status GraphDataClient::GetEdgesFromNodes(const std::vector &node_list, NodeType neighbor_type, - std::shared_ptr *out) { + const OutputFormat &format, std::shared_ptr *out) { #if !defined(_WIN32) && !defined(_WIN64) GnnGraphDataRequestPb request; GnnGraphDataResponsePb response; @@ -149,6 +149,7 @@ Status GraphDataClient::GetAllNeighbors(const std::vector &node_list request.add_id(static_cast(node_id)); } request.add_type(static_cast(neighbor_type)); + request.set_format(static_cast(format)); RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out)); #endif return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.h index 7d76d7fec75..c05ad375c9e 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_client.h @@ -82,11 +82,12 @@ class GraphDataClient : public GraphData { // All neighbors of the acquisition node. // @param std::vector node_list - List of nodes // @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported + // @param OutputFormat format - The storage format for output, normal, COO or CSR are valid // @param std::shared_ptr *out - Returned neighbor's id. Because the number of neighbors at different nodes is // different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors // is not enough, fill in tensor as -1. 
// @return Status The status code returned - Status GetAllNeighbors(const std::vector &node_list, NodeType neighbor_type, + Status GetAllNeighbors(const std::vector &node_list, NodeType neighbor_type, const OutputFormat &format, std::shared_ptr *out) override; // Get sampled neighbors. diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.cc index 53f05604fe1..100cdb0c605 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.cc @@ -153,23 +153,68 @@ Status GraphDataImpl::GetEdgesFromNodes(const std::vector &node_list, NodeType neighbor_type, - std::shared_ptr *out) { + const OutputFormat &format, std::shared_ptr *out) { CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty."); RETURN_IF_NOT_OK(CheckNeighborType(neighbor_type)); std::vector> neighbors; - size_t max_neighbor_num = 0; + + size_t max_neighbor_num = 0; // Special parameter for normal format + size_t total_edge_num = 0; // Special parameter for coo and csr format + std::vector offset_table(node_list.size(), 0); // Special parameter for csr format + + // Collect information of adjacent table neighbors.resize(node_list.size()); for (size_t i = 0; i < node_list.size(); ++i) { std::shared_ptr node; RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[i], &node)); - RETURN_IF_NOT_OK(node->GetAllNeighbors(neighbor_type, &neighbors[i])); - max_neighbor_num = max_neighbor_num > neighbors[i].size() ? max_neighbor_num : neighbors[i].size(); + if (format == OutputFormat::kNormal) { + RETURN_IF_NOT_OK(node->GetAllNeighbors(neighbor_type, &neighbors[i])); + max_neighbor_num = max_neighbor_num > neighbors[i].size() ? 
max_neighbor_num : neighbors[i].size(); + } else if (format == OutputFormat::kCoo) { + RETURN_IF_NOT_OK(node->GetAllNeighbors(neighbor_type, &neighbors[i], true)); + total_edge_num += neighbors[i].size(); + } else { + RETURN_IF_NOT_OK(node->GetAllNeighbors(neighbor_type, &neighbors[i], true)); + total_edge_num += neighbors[i].size(); + if (i < node_list.size() - 1) { + offset_table[i + 1] = total_edge_num; + } + } } - RETURN_IF_NOT_OK(ComplementVector(&neighbors, max_neighbor_num, kDefaultNodeId)); - RETURN_IF_NOT_OK(CreateTensorByVector(neighbors, DataType(DataType::DE_INT32), out)); - + // Using the information collected above, build the output tensor according to the requested + // output format + if (format == OutputFormat::kNormal) { + RETURN_IF_NOT_OK(ComplementVector(&neighbors, max_neighbor_num, kDefaultNodeId)); + RETURN_IF_NOT_OK(CreateTensorByVector(neighbors, DataType(DataType::DE_INT32), out)); + } else if (format == OutputFormat::kCoo) { + std::vector> coo_result; + coo_result.resize(total_edge_num); + size_t k = 0; + for (size_t i = 0; i < neighbors.size(); ++i) { + NodeIdType src = node_list[i]; + for (auto &dst : neighbors[i]) { + coo_result[k] = {src, dst}; + k++; + } + } + RETURN_IF_NOT_OK(CreateTensorByVector(coo_result, DataType(DataType::DE_INT32), out)); + } else { + std::vector> csr_result; + csr_result.resize(node_list.size() + total_edge_num); + for (size_t i = 0; i < offset_table.size(); ++i) { + csr_result[i] = {offset_table[i]}; + } + size_t edge_index = 0; + for (auto &neighbor : neighbors) { + for (auto &dst : neighbor) { + csr_result[node_list.size() + edge_index] = {dst}; + edge_index++; + } + } + RETURN_IF_NOT_OK(CreateTensorByVector(csr_result, DataType(DataType::DE_INT32), out)); + } return Status::OK(); } diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.h b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.h index 2437a52a075..9838ca4ca39 100644 --- 
a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.h +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_impl.h @@ -76,11 +76,12 @@ class GraphDataImpl : public GraphData { // All neighbors of the acquisition node. // @param std::vector node_list - List of nodes // @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported + // @param OutputFormat format - The storage format for output, normal, COO or CSR are valid // @param std::shared_ptr *out - Returned neighbor's id. Because the number of neighbors at different nodes is // different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors // is not enough, fill in tensor as -1. // @return Status The status code returned - Status GetAllNeighbors(const std::vector &node_list, NodeType neighbor_type, + Status GetAllNeighbors(const std::vector &node_list, NodeType neighbor_type, const OutputFormat &format, std::shared_ptr *out) override; // Get sampled neighbors. 
diff --git a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_service_impl.cc b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_service_impl.cc index df79e37e50a..035c4bfac78 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_service_impl.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/gnn/graph_data_service_impl.cc @@ -220,8 +220,10 @@ Status GraphDataServiceImpl::GetAllNeighbors(const GnnGraphDataRequestPb *reques node_list.resize(request->id().size()); std::transform(request->id().begin(), request->id().end(), node_list.begin(), [](const google::protobuf::int32 id) { return static_cast(id); }); + OutputFormat format = static_cast(request->format()); std::shared_ptr tensor; - RETURN_IF_NOT_OK(graph_data_impl_->GetAllNeighbors(node_list, static_cast(request->type()[0]), &tensor)); + RETURN_IF_NOT_OK( + graph_data_impl_->GetAllNeighbors(node_list, static_cast(request->type()[0]), format, &tensor)); TensorPb *result = response->add_result_data(); RETURN_IF_NOT_OK(TensorToPb(tensor, result)); return Status::OK(); diff --git a/mindspore/ccsrc/minddata/dataset/include/dataset/constants.h b/mindspore/ccsrc/minddata/dataset/include/dataset/constants.h index 4c28d943be7..3903fdea993 100644 --- a/mindspore/ccsrc/minddata/dataset/include/dataset/constants.h +++ b/mindspore/ccsrc/minddata/dataset/include/dataset/constants.h @@ -131,7 +131,14 @@ enum class SamplingStrategy { kEdgeWeight = 1 ///< Sampling with edge weight as probability. }; -// convenience functions for 32bit int bitmask. 
+/// \brief Output storage formats supported by the get-all-neighbors function of the GNN dataset +enum class OutputFormat { + kNormal = 0, ///< Normal format: one row per node, padded to the maximum neighbor count + kCoo = 1, ///< COO format: one (src, dst) pair per edge + kCsr = 2 ///< CSR format: offset table followed by the neighbor ids +}; + +// convenience functions for 32bit int bitmask inline bool BitTest(uint32_t bits, uint32_t bitMask) { return (bits & bitMask) == bitMask; } inline void BitSet(uint32_t *bits, uint32_t bitMask) { *bits |= bitMask; } diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_jpeg_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_jpeg_op.cc index 7e00ab3052f..f3f827ae36a 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_jpeg_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_jpeg_op.cc @@ -104,7 +104,7 @@ Status DvppDecodeJpegOp::Compute(const std::shared_ptr &input, std::shar uint32_t decoded_heightStride = DecodeOut->heightStride; uint32_t decoded_width = DecodeOut->width; uint32_t decoded_widthStride = DecodeOut->widthStride; - // std::cout << "Decoded size: " << decoded_width << ", " << decoded_height << std::endl; + const TensorShape dvpp_shape({dvpp_length, 1, 1}); const DataType dvpp_data_type(DataType::DE_UINT8); mindspore::dataset::Tensor::CreateFromMemory(dvpp_shape, dvpp_data_type, ret_ptr, output); diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_resize_crop_jpeg_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_resize_crop_jpeg_op.cc index 08c0d1d278c..4fd7b1b9ef1 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_resize_crop_jpeg_op.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_resize_crop_jpeg_op.cc @@ -39,7 +39,7 @@ Status DvppDecodeResizeCropJpegOp::Compute(const std::shared_ptr & RETURN_STATUS_UNEXPECTED(error); } std::shared_ptr CropOut(processor_->Get_Croped_DeviceData()); - // std::cout << "Decoded size: " << decoded_width << ", " << 
decoded_height << std::endl; + const TensorShape dvpp_shape({1, 1, 1}); const DataType dvpp_data_type(DataType::DE_UINT8); mindspore::dataset::DeviceTensor::CreateEmpty(dvpp_shape, dvpp_data_type, output); diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/DvppCommon.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/DvppCommon.cc index 04513ace609..0661badbc44 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/DvppCommon.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/DvppCommon.cc @@ -1051,7 +1051,6 @@ APP_ERROR DvppCommon::CombineJpegdProcess(const RawData &imageInfo, acldvppPixel inputImage_ = std::make_shared(); inputImage_->format = format; APP_ERROR ret = - // GetJpegImageInfo(imageInfo.data.get(), imageInfo.lenOfByte, inputImage_->width, inputImage_->height, components); GetJpegImageInfo(imageInfo.data, imageInfo.lenOfByte, inputImage_->width, inputImage_->height, components); if (ret != APP_ERR_OK) { MS_LOG(ERROR) << "Failed to get input image info, ret = " << ret << "."; @@ -1153,7 +1152,6 @@ APP_ERROR DvppCommon::CombinePngdProcess(const RawData &imageInfo, acldvppPixelF inputImage_ = std::make_shared(); inputImage_->format = format; APP_ERROR ret = - // GetJpegImageInfo(imageInfo.data.get(), imageInfo.lenOfByte, inputImage_->width, inputImage_->height, components); GetPngImageInfo(imageInfo.data, imageInfo.lenOfByte, inputImage_->width, inputImage_->height, components); if (ret != APP_ERR_OK) { MS_LOG(ERROR) << "Failed to get input image info, ret = " << ret << "."; @@ -1268,8 +1266,6 @@ APP_ERROR DvppCommon::TransferImageH2D(const RawData &imageInfo, const std::shar } // Copy the image data from host to device - // ret = aclrtMemcpyAsync(inDevBuff, imageInfo.lenOfByte, imageInfo.data.get(), imageInfo.lenOfByte, - // ACL_MEMCPY_HOST_TO_DEVICE, dvppStream_); ret = aclrtMemcpyAsync(inDevBuff, imageInfo.lenOfByte, imageInfo.data, imageInfo.lenOfByte, ACL_MEMCPY_HOST_TO_DEVICE, 
dvppStream_); if (ret != APP_ERR_OK) { @@ -1319,7 +1315,7 @@ APP_ERROR DvppCommon::SinkImageH2D(const RawData &imageInfo, acldvppPixelFormat // Get the buffer size(On device) of decode output according to the input data and output format uint32_t outBufferSize; - // ret = GetJpegDecodeDataSize(imageInfo.data.get(), imageInfo.lenOfByte, format, outBuffSize); + ret = GetJpegDecodeDataSize(imageInfo.data, imageInfo.lenOfByte, format, outBufferSize); if (ret != APP_ERR_OK) { MS_LOG(ERROR) << "Failed to get size of decode output buffer, ret = " << ret << "."; @@ -1377,7 +1373,7 @@ APP_ERROR DvppCommon::SinkImageH2D(const RawData &imageInfo) { // Get the buffer size of decode output according to the input data and output format uint32_t outBuffSize; - // ret = GetJpegDecodeDataSize(imageInfo.data.get(), imageInfo.lenOfByte, format, outBuffSize); + ret = GetPngDecodeDataSize(imageInfo.data, imageInfo.lenOfByte, format, outBuffSize); if (ret != APP_ERR_OK) { MS_LOG(ERROR) << "Failed to get size of decode output buffer, ret = " << ret << "."; diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/MDAclProcess.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/MDAclProcess.cc index 31ba5ce6fc5..55886fcdf80 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/MDAclProcess.cc +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/MDAclProcess.cc @@ -172,7 +172,7 @@ APP_ERROR MDAclProcess::H2D_Sink(const std::shared_ptrSizeInBytes(); - // MS_LOG(INFO) << "Filesize on host is: " << filesize; + imageinfo.lenOfByte = filesize; unsigned char *buffer = const_cast(input->GetBuffer()); imageinfo.data = static_cast(buffer); @@ -188,7 +188,7 @@ APP_ERROR MDAclProcess::H2D_Sink(const std::shared_ptrGetInputImage(); - // std::cout << "[DEBUG]Sink data sunccessfully, Filesize on device is: " << deviceInputData->dataSize << std::endl; + const mindspore::dataset::DataType dvpp_data_type(mindspore::dataset::DataType::DE_UINT8); const 
mindspore::dataset::TensorShape dvpp_shape({1, 1, 1}); mindspore::dataset::DeviceTensor::CreateEmpty(dvpp_shape, dvpp_data_type, &device_input); diff --git a/mindspore/dataset/engine/__init__.py b/mindspore/dataset/engine/__init__.py index e1d40d360e7..dcf1a8cfea7 100644 --- a/mindspore/dataset/engine/__init__.py +++ b/mindspore/dataset/engine/__init__.py @@ -25,7 +25,7 @@ operations for users to preprocess data: shuffle, batch, repeat, map, and zip. from ..core import config from .cache_client import DatasetCache from .datasets import * -from .graphdata import GraphData, SamplingStrategy +from .graphdata import GraphData, SamplingStrategy, OutputFormat from .iterators import * from .samplers import * from .serializer_deserializer import compare, deserialize, serialize, show diff --git a/mindspore/dataset/engine/graphdata.py b/mindspore/dataset/engine/graphdata.py index 1afde76afc0..5ef5d434418 100644 --- a/mindspore/dataset/engine/graphdata.py +++ b/mindspore/dataset/engine/graphdata.py @@ -24,6 +24,7 @@ from mindspore._c_dataengine import GraphDataClient from mindspore._c_dataengine import GraphDataServer from mindspore._c_dataengine import Tensor from mindspore._c_dataengine import SamplingStrategy as Sampling +from mindspore._c_dataengine import OutputFormat as Format from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \ check_gnn_get_nodes_from_edges, check_gnn_get_edges_from_nodes, check_gnn_get_all_neighbors, \ @@ -42,6 +43,19 @@ DE_C_INTER_SAMPLING_STRATEGY = { } +class OutputFormat(IntEnum): + NORMAL = 0 + COO = 1 + CSR = 2 + + +DE_C_INTER_OUTPUT_FORMAT = { + OutputFormat.NORMAL: Format.DE_FORMAT_NORMAL, + OutputFormat.COO: Format.DE_FORMAT_COO, + OutputFormat.CSR: Format.DE_FORMAT_CSR, +} + + class GraphData: """ Reads the graph dataset used for GNN training from the shared file and database. 
@@ -185,20 +199,65 @@ class GraphData: return self._graph_data.get_edges_from_nodes(node_list).as_array() @check_gnn_get_all_neighbors - def get_all_neighbors(self, node_list, neighbor_type): + def get_all_neighbors(self, node_list, neighbor_type, output_format=OutputFormat.NORMAL): """ Get `neighbor_type` neighbors of the nodes in `node_list`. Args: node_list (Union[list, numpy.ndarray]): The given list of nodes. neighbor_type (int): Specify the type of neighbor. + output_format (OutputFormat, optional): Output storage format (default=OutputFormat.NORMAL) + It can be any of [OutputFormat.NORMAL, OutputFormat.COO, OutputFormat.CSR]. Returns: + For NORMAL format or COO format numpy.ndarray, array of neighbors. + If CSR format is specified, two numpy.ndarrays are returned. + The first is the offset table, the second is the neighbors. Examples: + The following example illustrates the definition of these formats. 1 represents a connection + between two nodes, and 0 represents no connection. + Raw Data: + 0 1 2 3 + 0 0 1 0 0 + 1 0 0 1 0 + 2 1 0 0 1 + 3 0 1 0 0 + + Normal format >>> nodes = graph_dataset.get_all_nodes(node_type=1) >>> neighbors = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type=2) + NORMAL: + dst_0 dst_1 + 0 1 -1 + 1 2 -1 + 2 0 3 + 3 1 -1 + + COO format + >>> nodes = graph_dataset.get_all_nodes(node_type=1) + >>> neighbors_coo = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type=2, + output_format=OutputFormat.COO) + COO: + src dst + 0 1 + 1 2 + 2 0 + 2 3 + 3 1 + + CSR format + >>> nodes = graph_dataset.get_all_nodes(node_type=1) + >>> offset_table, neighbors_csr = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type=2, + output_format=OutputFormat.CSR) + CSR: + offset table: dst table: + 0 1 + 1 2 + 2 0 + 4 3 + 1 Raises: TypeError: If `node_list` is not list or ndarray. 
@@ -206,7 +265,13 @@ class GraphData: """ if self._working_mode == 'server': raise Exception("This method is not supported when working mode is server.") - return self._graph_data.get_all_neighbors(node_list, neighbor_type).as_array() + result_list = self._graph_data.get_all_neighbors(node_list, neighbor_type, + DE_C_INTER_OUTPUT_FORMAT[output_format]).as_array() + if output_format == OutputFormat.CSR: + offset_table = result_list[:len(node_list)] + neighbor_table = result_list[len(node_list):] + return offset_table, neighbor_table + return result_list @check_gnn_get_sampled_neighbors def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types, strategy=SamplingStrategy.RANDOM): diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index 4cb671abfed..9610c27e42d 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -1108,7 +1108,7 @@ def check_gnn_get_all_neighbors(method): @wraps(method) def new_method(self, *args, **kwargs): - [node_list, neighbour_type], _ = parse_user_args(method, *args, **kwargs) + [node_list, neighbour_type, _], _ = parse_user_args(method, *args, **kwargs) check_gnn_list_or_ndarray(node_list, 'node_list') type_check(neighbour_type, (int,), "neighbour_type") diff --git a/tests/ut/cpp/dataset/gnn_graph_test.cc b/tests/ut/cpp/dataset/gnn_graph_test.cc index 81990b972ba..03ff29ca58f 100644 --- a/tests/ut/cpp/dataset/gnn_graph_test.cc +++ b/tests/ut/cpp/dataset/gnn_graph_test.cc @@ -132,7 +132,7 @@ TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) { } } std::shared_ptr neighbors; - s = graph.GetAllNeighbors(node_list, meta_info.node_type[1], &neighbors); + s = graph.GetAllNeighbors(node_list, meta_info.node_type[1], OutputFormat::kNormal, &neighbors); EXPECT_TRUE(s.IsOk()); EXPECT_TRUE(neighbors->shape().ToString() == "<10,6>"); TensorRow features; @@ -151,6 +151,47 @@ TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) { 
EXPECT_TRUE(features[2]->ToString() == "Tensor (shape: <10>, Type: int32)\n[1,2,3,1,4,3,5,3,5,4]"); } +TEST_F(MindDataTestGNNGraph, TestGetAllNeighborsSpecialFormat) { + std::string path = "data/mindrecord/testGraphData/testdata"; + GraphDataImpl graph(path, 1); + Status s = graph.Init(); + EXPECT_TRUE(s.IsOk()); + + MetaInfo meta_info; + s = graph.GetMetaInfo(&meta_info); + EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(meta_info.node_type.size() == 2); + + std::shared_ptr nodes; + s = graph.GetAllNodes(meta_info.node_type[0], &nodes); + EXPECT_TRUE(s.IsOk()); + std::vector node_list; + for (auto itr = nodes->begin(); itr != nodes->end(); ++itr) { + node_list.push_back(*itr); + if (node_list.size() >= 10) { + break; + } + } + // Check COO format + std::shared_ptr neighbors_coo; + s = graph.GetAllNeighbors(node_list, meta_info.node_type[1], OutputFormat::kCoo, &neighbors_coo); + EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(neighbors_coo->shape().ToString() == "<20,2>"); + EXPECT_TRUE(neighbors_coo->ToString() == + "Tensor (shape: <20,2>, Type: int32)\n" + "[[101,201],[101,205],[101,206],[102,201],[102,202],[103,203],[103,205],[103,206],[103,207],[103,208]," + "[105,204],[106,202],[106,203],[107,201],[107,203],[107,207],[108,208],[109,210],[110,201],[110,210]]"); + // Check CSR format + std::shared_ptr neighbors_csr; + s = graph.GetAllNeighbors(node_list, meta_info.node_type[1], OutputFormat::kCsr, &neighbors_csr); + EXPECT_TRUE(s.IsOk()); + EXPECT_TRUE(neighbors_csr->shape().ToString() == "<30>"); + EXPECT_TRUE( + neighbors_csr->ToString() == + "Tensor (shape: <30>, Type: int32)\n" + "[0,3,5,10,10,11,13,16,17,18,201,205,206,201,202,203,205,206,207,208,204,202,203,201,203,207,208,210,201,210]"); +} + TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) { std::string path = "data/mindrecord/testGraphData/testdata"; GraphDataImpl graph(path, 1); diff --git a/tests/ut/python/dataset/test_graphdata.py b/tests/ut/python/dataset/test_graphdata.py index 8d35885395c..723b4bb5baf 100644 --- 
a/tests/ut/python/dataset/test_graphdata.py +++ b/tests/ut/python/dataset/test_graphdata.py @@ -18,6 +18,7 @@ import numpy as np import mindspore.dataset as ds from mindspore import log as logger from mindspore.dataset.engine import SamplingStrategy +from mindspore.dataset.engine import OutputFormat DATASET_FILE = "../data/mindrecord/testGraphData/testdata" SOCIAL_DATA_FILE = "../data/mindrecord/testGraphData/sns" @@ -37,6 +38,23 @@ def test_graphdata_getfullneighbor(): assert row_tensor[0].shape == (10, 6) +def test_graphdata_getallneighbors_special_format(): + """ + Test get all neighbors with special format + """ + logger.info('test get all neighbors with special format.\n') + g = ds.GraphData(DATASET_FILE, 2) + nodes = g.get_all_nodes(1) + assert len(nodes) == 10 + + neighbor_coo = g.get_all_neighbors(nodes, 2, OutputFormat.COO) + assert neighbor_coo.shape == (20, 2) + + offset_table, neighbor_csr = g.get_all_neighbors(nodes, 2, OutputFormat.CSR) + assert offset_table.shape == (10,) + assert neighbor_csr.shape == (20,) + + def test_graphdata_getnodefeature_input_check(): """ Test get node feature input check diff --git a/tests/ut/python/dataset/test_graphdata_distributed.py b/tests/ut/python/dataset/test_graphdata_distributed.py index 22c8c6fac49..5b8ae99dd70 100644 --- a/tests/ut/python/dataset/test_graphdata_distributed.py +++ b/tests/ut/python/dataset/test_graphdata_distributed.py @@ -21,6 +21,7 @@ import numpy as np import mindspore.dataset as ds from mindspore import log as logger from mindspore.dataset.engine import SamplingStrategy +from mindspore.dataset.engine import OutputFormat DATASET_FILE = "../data/mindrecord/testGraphData/testdata" @@ -105,6 +106,14 @@ def test_graphdata_distributed(): [0, 1, 1, 0, 0], [0, 1, 0, 1, 0]] assert row_tensor[2].tolist() == [1, 2, 3, 1, 4, 3, 5, 3, 5, 4] + neighbor_normal = g.get_all_neighbors(nodes, 2, OutputFormat.NORMAL) + assert neighbor_normal.shape == (10, 6) + neighbor_coo = g.get_all_neighbors(nodes, 2, 
OutputFormat.COO) + assert neighbor_coo.shape == (20, 2) + offset_table, neighbor_csr = g.get_all_neighbors(nodes, 2, OutputFormat.CSR) + assert offset_table.shape == (10,) + assert neighbor_csr.shape == (20,) + edges = g.get_all_edges(0) assert edges.tolist() == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]