forked from mindspore-Ecosystem/mindspore
Implement new gnn feature
This commit is contained in:
parent
a4e2ab0487
commit
2c7d1bf603
|
@ -63,9 +63,10 @@ PYBIND_REGISTER(
|
|||
return out;
|
||||
})
|
||||
.def("get_all_neighbors",
|
||||
[](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeType neighbor_type) {
|
||||
[](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeType neighbor_type,
|
||||
OutputFormat format) {
|
||||
std::shared_ptr<Tensor> out;
|
||||
THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, &out));
|
||||
THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, format, &out));
|
||||
return out;
|
||||
})
|
||||
.def("get_sampled_neighbors",
|
||||
|
@ -129,5 +130,13 @@ PYBIND_REGISTER(SamplingStrategy, 0, ([](const py::module *m) {
|
|||
.export_values();
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(OutputFormat, 0, ([](const py::module *m) {
|
||||
(void)py::enum_<OutputFormat>(*m, "OutputFormat", py::arithmetic())
|
||||
.value("DE_FORMAT_NORMAL", OutputFormat::kNormal)
|
||||
.value("DE_FORMAT_COO", OutputFormat::kCoo)
|
||||
.value("DE_FORMAT_CSR", OutputFormat::kCsr)
|
||||
.export_values();
|
||||
}));
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -79,6 +79,7 @@ message GnnGraphDataRequestPb {
|
|||
GnnRandomWalkPb random_walk = 6;
|
||||
int32 strategy = 7;
|
||||
repeated IdPairPb node_pair = 8;
|
||||
int32 format = 9; // output format for GET_ALL_NEIGHBORS function
|
||||
}
|
||||
|
||||
message GnnGraphDataResponsePb {
|
||||
|
|
|
@ -72,12 +72,13 @@ class GraphData {
|
|||
// All neighbors of the acquisition node.
|
||||
// @param std::vector<NodeType> node_list - List of nodes
|
||||
// @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported
|
||||
// @param OutputFormat format - The storage format for output, normal, COO or CSR are valid
|
||||
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is
|
||||
// different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors
|
||||
// is not enough, fill in tensor as -1.
|
||||
// @return Status The status code returned
|
||||
virtual Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
|
||||
std::shared_ptr<Tensor> *out) = 0;
|
||||
const OutputFormat &format, std::shared_ptr<Tensor> *out) = 0;
|
||||
|
||||
// Get sampled neighbors.
|
||||
// @param std::vector<NodeType> node_list - List of nodes
|
||||
|
|
|
@ -140,7 +140,7 @@ Status GraphDataClient::GetEdgesFromNodes(const std::vector<std::pair<NodeIdType
|
|||
}
|
||||
|
||||
Status GraphDataClient::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
|
||||
std::shared_ptr<Tensor> *out) {
|
||||
const OutputFormat &format, std::shared_ptr<Tensor> *out) {
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
GnnGraphDataRequestPb request;
|
||||
GnnGraphDataResponsePb response;
|
||||
|
@ -149,6 +149,7 @@ Status GraphDataClient::GetAllNeighbors(const std::vector<NodeIdType> &node_list
|
|||
request.add_id(static_cast<google::protobuf::int32>(node_id));
|
||||
}
|
||||
request.add_type(static_cast<google::protobuf::int32>(neighbor_type));
|
||||
request.set_format(static_cast<google::protobuf::int32>(format));
|
||||
RETURN_IF_NOT_OK(GetGraphDataTensor(request, &response, out));
|
||||
#endif
|
||||
return Status::OK();
|
||||
|
|
|
@ -82,11 +82,12 @@ class GraphDataClient : public GraphData {
|
|||
// All neighbors of the acquisition node.
|
||||
// @param std::vector<NodeType> node_list - List of nodes
|
||||
// @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported
|
||||
// @param OutputFormat format - The storage format for output, normal, COO or CSR are valid
|
||||
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is
|
||||
// different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors
|
||||
// is not enough, fill in tensor as -1.
|
||||
// @return Status The status code returned
|
||||
Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
|
||||
Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, const OutputFormat &format,
|
||||
std::shared_ptr<Tensor> *out) override;
|
||||
|
||||
// Get sampled neighbors.
|
||||
|
|
|
@ -153,23 +153,68 @@ Status GraphDataImpl::GetEdgesFromNodes(const std::vector<std::pair<NodeIdType,
|
|||
}
|
||||
|
||||
Status GraphDataImpl::GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
|
||||
std::shared_ptr<Tensor> *out) {
|
||||
const OutputFormat &format, std::shared_ptr<Tensor> *out) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(!node_list.empty(), "Input node_list is empty.");
|
||||
RETURN_IF_NOT_OK(CheckNeighborType(neighbor_type));
|
||||
|
||||
std::vector<std::vector<NodeIdType>> neighbors;
|
||||
size_t max_neighbor_num = 0;
|
||||
|
||||
size_t max_neighbor_num = 0; // Special parameter for normal format
|
||||
size_t total_edge_num = 0; // Special parameter for coo and csr format
|
||||
std::vector<NodeIdType> offset_table(node_list.size(), 0); // Special parameter for csr format
|
||||
|
||||
// Collect information of adjacent table
|
||||
neighbors.resize(node_list.size());
|
||||
for (size_t i = 0; i < node_list.size(); ++i) {
|
||||
std::shared_ptr<Node> node;
|
||||
RETURN_IF_NOT_OK(GetNodeByNodeId(node_list[i], &node));
|
||||
if (format == OutputFormat::kNormal) {
|
||||
RETURN_IF_NOT_OK(node->GetAllNeighbors(neighbor_type, &neighbors[i]));
|
||||
max_neighbor_num = max_neighbor_num > neighbors[i].size() ? max_neighbor_num : neighbors[i].size();
|
||||
} else if (format == OutputFormat::kCoo) {
|
||||
RETURN_IF_NOT_OK(node->GetAllNeighbors(neighbor_type, &neighbors[i], true));
|
||||
total_edge_num += neighbors[i].size();
|
||||
} else {
|
||||
RETURN_IF_NOT_OK(node->GetAllNeighbors(neighbor_type, &neighbors[i], true));
|
||||
total_edge_num += neighbors[i].size();
|
||||
if (i < node_list.size() - 1) {
|
||||
offset_table[i + 1] = total_edge_num;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// By applying those information we obtained above, deal with the output with corresponding to
|
||||
// output format
|
||||
if (format == OutputFormat::kNormal) {
|
||||
RETURN_IF_NOT_OK(ComplementVector<NodeIdType>(&neighbors, max_neighbor_num, kDefaultNodeId));
|
||||
RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(neighbors, DataType(DataType::DE_INT32), out));
|
||||
|
||||
} else if (format == OutputFormat::kCoo) {
|
||||
std::vector<std::vector<NodeIdType>> coo_result;
|
||||
coo_result.resize(total_edge_num);
|
||||
size_t k = 0;
|
||||
for (size_t i = 0; i < neighbors.size(); ++i) {
|
||||
NodeIdType src = node_list[i];
|
||||
for (auto &dst : neighbors[i]) {
|
||||
coo_result[k] = {src, dst};
|
||||
k++;
|
||||
}
|
||||
}
|
||||
RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(coo_result, DataType(DataType::DE_INT32), out));
|
||||
} else {
|
||||
std::vector<std::vector<NodeIdType>> csr_result;
|
||||
csr_result.resize(node_list.size() + total_edge_num);
|
||||
for (size_t i = 0; i < offset_table.size(); ++i) {
|
||||
csr_result[i] = {offset_table[i]};
|
||||
}
|
||||
size_t edge_index = 0;
|
||||
for (auto &neighbor : neighbors) {
|
||||
for (auto &dst : neighbor) {
|
||||
csr_result[node_list.size() + edge_index] = {dst};
|
||||
edge_index++;
|
||||
}
|
||||
}
|
||||
RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(csr_result, DataType(DataType::DE_INT32), out));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -76,11 +76,12 @@ class GraphDataImpl : public GraphData {
|
|||
// All neighbors of the acquisition node.
|
||||
// @param std::vector<NodeType> node_list - List of nodes
|
||||
// @param NodeType neighbor_type - The type of neighbor. If the type does not exist, an error will be reported
|
||||
// @param OutputFormat format - The storage format for output, normal, COO or CSR are valid
|
||||
// @param std::shared_ptr<Tensor> *out - Returned neighbor's id. Because the number of neighbors at different nodes is
|
||||
// different, the returned tensor is output according to the maximum number of neighbors. If the number of neighbors
|
||||
// is not enough, fill in tensor as -1.
|
||||
// @return Status The status code returned
|
||||
Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type,
|
||||
Status GetAllNeighbors(const std::vector<NodeIdType> &node_list, NodeType neighbor_type, const OutputFormat &format,
|
||||
std::shared_ptr<Tensor> *out) override;
|
||||
|
||||
// Get sampled neighbors.
|
||||
|
|
|
@ -220,8 +220,10 @@ Status GraphDataServiceImpl::GetAllNeighbors(const GnnGraphDataRequestPb *reques
|
|||
node_list.resize(request->id().size());
|
||||
std::transform(request->id().begin(), request->id().end(), node_list.begin(),
|
||||
[](const google::protobuf::int32 id) { return static_cast<NodeIdType>(id); });
|
||||
OutputFormat format = static_cast<OutputFormat>(request->format());
|
||||
std::shared_ptr<Tensor> tensor;
|
||||
RETURN_IF_NOT_OK(graph_data_impl_->GetAllNeighbors(node_list, static_cast<NodeType>(request->type()[0]), &tensor));
|
||||
RETURN_IF_NOT_OK(
|
||||
graph_data_impl_->GetAllNeighbors(node_list, static_cast<NodeType>(request->type()[0]), format, &tensor));
|
||||
TensorPb *result = response->add_result_data();
|
||||
RETURN_IF_NOT_OK(TensorToPb(tensor, result));
|
||||
return Status::OK();
|
||||
|
|
|
@ -131,7 +131,14 @@ enum class SamplingStrategy {
|
|||
kEdgeWeight = 1 ///< Sampling with edge weight as probability.
|
||||
};
|
||||
|
||||
// convenience functions for 32bit int bitmask.
|
||||
/// \brief Possible values for output format in get all neighbors function of gnn dataset
|
||||
enum class OutputFormat {
|
||||
kNormal = 0, ///< Normal format>
|
||||
kCoo = 1, ///< COO format>
|
||||
kCsr = 2 ///< CSR format>
|
||||
};
|
||||
|
||||
// convenience functions for 32bit int bitmask
|
||||
inline bool BitTest(uint32_t bits, uint32_t bitMask) { return (bits & bitMask) == bitMask; }
|
||||
|
||||
inline void BitSet(uint32_t *bits, uint32_t bitMask) { *bits |= bitMask; }
|
||||
|
|
|
@ -104,7 +104,7 @@ Status DvppDecodeJpegOp::Compute(const std::shared_ptr<Tensor> &input, std::shar
|
|||
uint32_t decoded_heightStride = DecodeOut->heightStride;
|
||||
uint32_t decoded_width = DecodeOut->width;
|
||||
uint32_t decoded_widthStride = DecodeOut->widthStride;
|
||||
// std::cout << "Decoded size: " << decoded_width << ", " << decoded_height << std::endl;
|
||||
|
||||
const TensorShape dvpp_shape({dvpp_length, 1, 1});
|
||||
const DataType dvpp_data_type(DataType::DE_UINT8);
|
||||
mindspore::dataset::Tensor::CreateFromMemory(dvpp_shape, dvpp_data_type, ret_ptr, output);
|
||||
|
|
|
@ -39,7 +39,7 @@ Status DvppDecodeResizeCropJpegOp::Compute(const std::shared_ptr<DeviceTensor> &
|
|||
RETURN_STATUS_UNEXPECTED(error);
|
||||
}
|
||||
std::shared_ptr<DvppDataInfo> CropOut(processor_->Get_Croped_DeviceData());
|
||||
// std::cout << "Decoded size: " << decoded_width << ", " << decoded_height << std::endl;
|
||||
|
||||
const TensorShape dvpp_shape({1, 1, 1});
|
||||
const DataType dvpp_data_type(DataType::DE_UINT8);
|
||||
mindspore::dataset::DeviceTensor::CreateEmpty(dvpp_shape, dvpp_data_type, output);
|
||||
|
|
|
@ -1051,7 +1051,6 @@ APP_ERROR DvppCommon::CombineJpegdProcess(const RawData &imageInfo, acldvppPixel
|
|||
inputImage_ = std::make_shared<DvppDataInfo>();
|
||||
inputImage_->format = format;
|
||||
APP_ERROR ret =
|
||||
// GetJpegImageInfo(imageInfo.data.get(), imageInfo.lenOfByte, inputImage_->width, inputImage_->height, components);
|
||||
GetJpegImageInfo(imageInfo.data, imageInfo.lenOfByte, inputImage_->width, inputImage_->height, components);
|
||||
if (ret != APP_ERR_OK) {
|
||||
MS_LOG(ERROR) << "Failed to get input image info, ret = " << ret << ".";
|
||||
|
@ -1153,7 +1152,6 @@ APP_ERROR DvppCommon::CombinePngdProcess(const RawData &imageInfo, acldvppPixelF
|
|||
inputImage_ = std::make_shared<DvppDataInfo>();
|
||||
inputImage_->format = format;
|
||||
APP_ERROR ret =
|
||||
// GetJpegImageInfo(imageInfo.data.get(), imageInfo.lenOfByte, inputImage_->width, inputImage_->height, components);
|
||||
GetPngImageInfo(imageInfo.data, imageInfo.lenOfByte, inputImage_->width, inputImage_->height, components);
|
||||
if (ret != APP_ERR_OK) {
|
||||
MS_LOG(ERROR) << "Failed to get input image info, ret = " << ret << ".";
|
||||
|
@ -1268,8 +1266,6 @@ APP_ERROR DvppCommon::TransferImageH2D(const RawData &imageInfo, const std::shar
|
|||
}
|
||||
|
||||
// Copy the image data from host to device
|
||||
// ret = aclrtMemcpyAsync(inDevBuff, imageInfo.lenOfByte, imageInfo.data.get(), imageInfo.lenOfByte,
|
||||
// ACL_MEMCPY_HOST_TO_DEVICE, dvppStream_);
|
||||
ret = aclrtMemcpyAsync(inDevBuff, imageInfo.lenOfByte, imageInfo.data, imageInfo.lenOfByte, ACL_MEMCPY_HOST_TO_DEVICE,
|
||||
dvppStream_);
|
||||
if (ret != APP_ERR_OK) {
|
||||
|
@ -1319,7 +1315,7 @@ APP_ERROR DvppCommon::SinkImageH2D(const RawData &imageInfo, acldvppPixelFormat
|
|||
|
||||
// Get the buffer size(On device) of decode output according to the input data and output format
|
||||
uint32_t outBufferSize;
|
||||
// ret = GetJpegDecodeDataSize(imageInfo.data.get(), imageInfo.lenOfByte, format, outBuffSize);
|
||||
|
||||
ret = GetJpegDecodeDataSize(imageInfo.data, imageInfo.lenOfByte, format, outBufferSize);
|
||||
if (ret != APP_ERR_OK) {
|
||||
MS_LOG(ERROR) << "Failed to get size of decode output buffer, ret = " << ret << ".";
|
||||
|
@ -1377,7 +1373,7 @@ APP_ERROR DvppCommon::SinkImageH2D(const RawData &imageInfo) {
|
|||
|
||||
// Get the buffer size of decode output according to the input data and output format
|
||||
uint32_t outBuffSize;
|
||||
// ret = GetJpegDecodeDataSize(imageInfo.data.get(), imageInfo.lenOfByte, format, outBuffSize);
|
||||
|
||||
ret = GetPngDecodeDataSize(imageInfo.data, imageInfo.lenOfByte, format, outBuffSize);
|
||||
if (ret != APP_ERR_OK) {
|
||||
MS_LOG(ERROR) << "Failed to get size of decode output buffer, ret = " << ret << ".";
|
||||
|
|
|
@ -172,7 +172,7 @@ APP_ERROR MDAclProcess::H2D_Sink(const std::shared_ptr<mindspore::dataset::Tenso
|
|||
|
||||
RawData imageinfo;
|
||||
uint32_t filesize = input->SizeInBytes();
|
||||
// MS_LOG(INFO) << "Filesize on host is: " << filesize;
|
||||
|
||||
imageinfo.lenOfByte = filesize;
|
||||
unsigned char *buffer = const_cast<unsigned char *>(input->GetBuffer());
|
||||
imageinfo.data = static_cast<void *>(buffer);
|
||||
|
@ -188,7 +188,7 @@ APP_ERROR MDAclProcess::H2D_Sink(const std::shared_ptr<mindspore::dataset::Tenso
|
|||
return ret;
|
||||
}
|
||||
auto deviceInputData = dvppCommon_->GetInputImage();
|
||||
// std::cout << "[DEBUG]Sink data sunccessfully, Filesize on device is: " << deviceInputData->dataSize << std::endl;
|
||||
|
||||
const mindspore::dataset::DataType dvpp_data_type(mindspore::dataset::DataType::DE_UINT8);
|
||||
const mindspore::dataset::TensorShape dvpp_shape({1, 1, 1});
|
||||
mindspore::dataset::DeviceTensor::CreateEmpty(dvpp_shape, dvpp_data_type, &device_input);
|
||||
|
|
|
@ -25,7 +25,7 @@ operations for users to preprocess data: shuffle, batch, repeat, map, and zip.
|
|||
from ..core import config
|
||||
from .cache_client import DatasetCache
|
||||
from .datasets import *
|
||||
from .graphdata import GraphData, SamplingStrategy
|
||||
from .graphdata import GraphData, SamplingStrategy, OutputFormat
|
||||
from .iterators import *
|
||||
from .samplers import *
|
||||
from .serializer_deserializer import compare, deserialize, serialize, show
|
||||
|
|
|
@ -24,6 +24,7 @@ from mindspore._c_dataengine import GraphDataClient
|
|||
from mindspore._c_dataengine import GraphDataServer
|
||||
from mindspore._c_dataengine import Tensor
|
||||
from mindspore._c_dataengine import SamplingStrategy as Sampling
|
||||
from mindspore._c_dataengine import OutputFormat as Format
|
||||
|
||||
from .validators import check_gnn_graphdata, check_gnn_get_all_nodes, check_gnn_get_all_edges, \
|
||||
check_gnn_get_nodes_from_edges, check_gnn_get_edges_from_nodes, check_gnn_get_all_neighbors, \
|
||||
|
@ -42,6 +43,19 @@ DE_C_INTER_SAMPLING_STRATEGY = {
|
|||
}
|
||||
|
||||
|
||||
class OutputFormat(IntEnum):
|
||||
NORMAL = 0
|
||||
COO = 1
|
||||
CSR = 2
|
||||
|
||||
|
||||
DE_C_INTER_OUTPUT_FORMAT = {
|
||||
OutputFormat.NORMAL: Format.DE_FORMAT_NORMAL,
|
||||
OutputFormat.COO: Format.DE_FORMAT_COO,
|
||||
OutputFormat.CSR: Format.DE_FORMAT_CSR,
|
||||
}
|
||||
|
||||
|
||||
class GraphData:
|
||||
"""
|
||||
Reads the graph dataset used for GNN training from the shared file and database.
|
||||
|
@ -185,20 +199,65 @@ class GraphData:
|
|||
return self._graph_data.get_edges_from_nodes(node_list).as_array()
|
||||
|
||||
@check_gnn_get_all_neighbors
|
||||
def get_all_neighbors(self, node_list, neighbor_type):
|
||||
def get_all_neighbors(self, node_list, neighbor_type, output_format=OutputFormat.NORMAL):
|
||||
"""
|
||||
Get `neighbor_type` neighbors of the nodes in `node_list`.
|
||||
|
||||
Args:
|
||||
node_list (Union[list, numpy.ndarray]): The given list of nodes.
|
||||
neighbor_type (int): Specify the type of neighbor.
|
||||
output_format (OutputFormat, optional): Output storage format (default=OutputFormat.NORMAL)
|
||||
It can be any of [OutputFormat.NORMAL, OutputFormat.COO, OutputFormat.CSR].
|
||||
|
||||
Returns:
|
||||
For NORMAL format or COO format
|
||||
numpy.ndarray, array of neighbors.
|
||||
If CSR format is specified, two numpy.ndarrays will return.
|
||||
The first is offset table, the second is neighbors
|
||||
|
||||
Examples:
|
||||
We try to use the following example to illustrate the definition of these formats. 1 represents connected
|
||||
between two nodes, and 0 represents not connected.
|
||||
Raw Data:
|
||||
0 1 2 3
|
||||
0 0 1 0 0
|
||||
1 0 0 1 0
|
||||
2 1 0 0 1
|
||||
3 1 0 0 0
|
||||
|
||||
Normal format
|
||||
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
|
||||
>>> neighbors = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type=2)
|
||||
NORMAL:
|
||||
dst_0 dst_1
|
||||
0 1 -1
|
||||
1 2 -1
|
||||
2 0 3
|
||||
3 1 -1
|
||||
|
||||
COO format
|
||||
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
|
||||
>>> neighbors_coo = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type=2,
|
||||
output_format=OutputFormat.COO)
|
||||
COO:
|
||||
src dst
|
||||
0 1
|
||||
1 2
|
||||
2 0
|
||||
2 3
|
||||
3 1
|
||||
|
||||
CSR format
|
||||
>>> nodes = graph_dataset.get_all_nodes(node_type=1)
|
||||
>>> offset_table, neighbors_csr = graph_dataset.get_all_neighbors(node_list=nodes, neighbor_type=2,
|
||||
output_format=OutputFormat.CSR)
|
||||
CSR:
|
||||
offset table: dst table:
|
||||
0 1
|
||||
1 2
|
||||
2 0
|
||||
4 3
|
||||
1
|
||||
|
||||
Raises:
|
||||
TypeError: If `node_list` is not list or ndarray.
|
||||
|
@ -206,7 +265,13 @@ class GraphData:
|
|||
"""
|
||||
if self._working_mode == 'server':
|
||||
raise Exception("This method is not supported when working mode is server.")
|
||||
return self._graph_data.get_all_neighbors(node_list, neighbor_type).as_array()
|
||||
result_list = self._graph_data.get_all_neighbors(node_list, neighbor_type,
|
||||
DE_C_INTER_OUTPUT_FORMAT[output_format]).as_array()
|
||||
if output_format == OutputFormat.CSR:
|
||||
offset_table = result_list[:len(node_list)]
|
||||
neighbor_table = result_list[len(node_list):]
|
||||
return offset_table, neighbor_table
|
||||
return result_list
|
||||
|
||||
@check_gnn_get_sampled_neighbors
|
||||
def get_sampled_neighbors(self, node_list, neighbor_nums, neighbor_types, strategy=SamplingStrategy.RANDOM):
|
||||
|
|
|
@ -1108,7 +1108,7 @@ def check_gnn_get_all_neighbors(method):
|
|||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[node_list, neighbour_type], _ = parse_user_args(method, *args, **kwargs)
|
||||
[node_list, neighbour_type, _], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
check_gnn_list_or_ndarray(node_list, 'node_list')
|
||||
type_check(neighbour_type, (int,), "neighbour_type")
|
||||
|
|
|
@ -132,7 +132,7 @@ TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) {
|
|||
}
|
||||
}
|
||||
std::shared_ptr<Tensor> neighbors;
|
||||
s = graph.GetAllNeighbors(node_list, meta_info.node_type[1], &neighbors);
|
||||
s = graph.GetAllNeighbors(node_list, meta_info.node_type[1], OutputFormat::kNormal, &neighbors);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
EXPECT_TRUE(neighbors->shape().ToString() == "<10,6>");
|
||||
TensorRow features;
|
||||
|
@ -151,6 +151,47 @@ TEST_F(MindDataTestGNNGraph, TestGetAllNeighbors) {
|
|||
EXPECT_TRUE(features[2]->ToString() == "Tensor (shape: <10>, Type: int32)\n[1,2,3,1,4,3,5,3,5,4]");
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestGNNGraph, TestGetAllNeighborsSpecialFormat) {
|
||||
std::string path = "data/mindrecord/testGraphData/testdata";
|
||||
GraphDataImpl graph(path, 1);
|
||||
Status s = graph.Init();
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
|
||||
MetaInfo meta_info;
|
||||
s = graph.GetMetaInfo(&meta_info);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
EXPECT_TRUE(meta_info.node_type.size() == 2);
|
||||
|
||||
std::shared_ptr<Tensor> nodes;
|
||||
s = graph.GetAllNodes(meta_info.node_type[0], &nodes);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
std::vector<NodeIdType> node_list;
|
||||
for (auto itr = nodes->begin<NodeIdType>(); itr != nodes->end<NodeIdType>(); ++itr) {
|
||||
node_list.push_back(*itr);
|
||||
if (node_list.size() >= 10) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Check COO format
|
||||
std::shared_ptr<Tensor> neighbors_coo;
|
||||
s = graph.GetAllNeighbors(node_list, meta_info.node_type[1], OutputFormat::kCoo, &neighbors_coo);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
EXPECT_TRUE(neighbors_coo->shape().ToString() == "<20,2>");
|
||||
EXPECT_TRUE(neighbors_coo->ToString() ==
|
||||
"Tensor (shape: <20,2>, Type: int32)\n"
|
||||
"[[101,201],[101,205],[101,206],[102,201],[102,202],[103,203],[103,205],[103,206],[103,207],[103,208],"
|
||||
"[105,204],[106,202],[106,203],[107,201],[107,203],[107,207],[108,208],[109,210],[110,201],[110,210]]");
|
||||
// Check CSR format
|
||||
std::shared_ptr<Tensor> neighbors_csr;
|
||||
s = graph.GetAllNeighbors(node_list, meta_info.node_type[1], OutputFormat::kCsr, &neighbors_csr);
|
||||
EXPECT_TRUE(s.IsOk());
|
||||
EXPECT_TRUE(neighbors_csr->shape().ToString() == "<30>");
|
||||
EXPECT_TRUE(
|
||||
neighbors_csr->ToString() ==
|
||||
"Tensor (shape: <30>, Type: int32)\n"
|
||||
"[0,3,5,10,10,11,13,16,17,18,201,205,206,201,202,203,205,206,207,208,204,202,203,201,203,207,208,210,201,210]");
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestGNNGraph, TestGetSampledNeighbors) {
|
||||
std::string path = "data/mindrecord/testGraphData/testdata";
|
||||
GraphDataImpl graph(path, 1);
|
||||
|
|
|
@ -18,6 +18,7 @@ import numpy as np
|
|||
import mindspore.dataset as ds
|
||||
from mindspore import log as logger
|
||||
from mindspore.dataset.engine import SamplingStrategy
|
||||
from mindspore.dataset.engine import OutputFormat
|
||||
|
||||
DATASET_FILE = "../data/mindrecord/testGraphData/testdata"
|
||||
SOCIAL_DATA_FILE = "../data/mindrecord/testGraphData/sns"
|
||||
|
@ -37,6 +38,23 @@ def test_graphdata_getfullneighbor():
|
|||
assert row_tensor[0].shape == (10, 6)
|
||||
|
||||
|
||||
def test_graphdata_getallneighbors_special_format():
|
||||
"""
|
||||
Test get all neighbors with special format
|
||||
"""
|
||||
logger.info('test get all neighbors with special format.\n')
|
||||
g = ds.GraphData(DATASET_FILE, 2)
|
||||
nodes = g.get_all_nodes(1)
|
||||
assert len(nodes) == 10
|
||||
|
||||
neighbor_coo = g.get_all_neighbors(nodes, 2, OutputFormat.COO)
|
||||
assert neighbor_coo.shape == (20, 2)
|
||||
|
||||
offset_table, neighbor_csr = g.get_all_neighbors(nodes, 2, OutputFormat.CSR)
|
||||
assert offset_table.shape == (10,)
|
||||
assert neighbor_csr.shape == (20,)
|
||||
|
||||
|
||||
def test_graphdata_getnodefeature_input_check():
|
||||
"""
|
||||
Test get node feature input check
|
||||
|
|
|
@ -21,6 +21,7 @@ import numpy as np
|
|||
import mindspore.dataset as ds
|
||||
from mindspore import log as logger
|
||||
from mindspore.dataset.engine import SamplingStrategy
|
||||
from mindspore.dataset.engine import OutputFormat
|
||||
|
||||
DATASET_FILE = "../data/mindrecord/testGraphData/testdata"
|
||||
|
||||
|
@ -105,6 +106,14 @@ def test_graphdata_distributed():
|
|||
[0, 1, 1, 0, 0], [0, 1, 0, 1, 0]]
|
||||
assert row_tensor[2].tolist() == [1, 2, 3, 1, 4, 3, 5, 3, 5, 4]
|
||||
|
||||
neighbor_normal = g.get_all_neighbors(nodes, 2, OutputFormat.NORMAL)
|
||||
assert neighbor_normal.shape == (10, 6)
|
||||
neighbor_coo = g.get_all_neighbors(nodes, 2, OutputFormat.COO)
|
||||
assert neighbor_coo.shape == (20, 2)
|
||||
offset_table, neighbor_csr = g.get_all_neighbors(nodes, 2, OutputFormat.CSR)
|
||||
assert offset_table.shape == (10,)
|
||||
assert neighbor_csr.shape == (20,)
|
||||
|
||||
edges = g.get_all_edges(0)
|
||||
assert edges.tolist() == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20,
|
||||
21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]
|
||||
|
|
Loading…
Reference in New Issue