[Model] Implement model Unimp#83
support ogbn dataset
|
get ogb node dataset via OgbNodeDataset class |
Zhanghyi left a comment:
Please add some test code: for graph, link, and node datasets, pick the smallest dataset of each kind and write unit tests for it. Put the tests under https://github.com/BUPT-GAMMA/GammaGL/tree/main/tests/datasets
| import os.path as osp | ||
| import numpy as np | ||
| from gammagl.data import InMemoryDataset | ||
| from gammgl.utils.ogb_url import decide_download, download_url, extract_zip |
| @@ -0,0 +1,91 @@ | |||
| import urllib.request as ur | |||
| # check if previously-downloaded folder exists. | ||
| # If so, use that one. | ||
| if osp.exists(osp.join(root, self.dir_name + '_pyg')): |
| print(data) | ||
| print(data[0]) | ||
|
|
||
| test() No newline at end of file |
Rename the function "test" to "test_ogbgraphdataset" so that it is easy to distinguish from the test functions in other files.
| data=OgbLinkDataset('ogbl-ppa') | ||
| print(data[0]) | ||
|
|
||
| test() No newline at end of file |
| data=OgbNodeDataset('ogbn-arxiv') | ||
| print(data[0]) | ||
|
|
||
| test() No newline at end of file |
| from gammagl.data import InMemoryDataset | ||
| from gammagl.data.download import download_url | ||
| from gammagl.data.extract import extract_zip | ||
| from gammagl.io.read_ogb import read_node_label_hetero, read_graph, read_heterograph, read_nodesplitidx_split_hetero |
| print(data) | ||
| print(data[0]) | ||
|
|
||
| test() No newline at end of file |
- Rename the function "test" to "test_ogbgraphdataset" so that it is easy to distinguish from the test functions in other files.
- Use "assert" to check that key variables are correct, e.g. "feature.shape[0]" should equal "num_nodes".
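A minimal sketch of what such a test could look like. To keep it runnable without downloading anything, it uses a hypothetical in-memory stand-in (`make_toy_graph`) instead of a real OGB dataset; the attribute names `x` and `edge_index`, the helper `check_graph`, and the test name `test_ogbnodedataset` are assumptions for illustration, not names confirmed by this PR.

```python
from types import SimpleNamespace

import numpy as np


def make_toy_graph(num_nodes=5, num_features=3):
    """Hypothetical stand-in for one graph returned by a dataset class."""
    return SimpleNamespace(
        x=np.zeros((num_nodes, num_features), dtype=np.float32),
        edge_index=np.array([[0, 1, 2], [1, 2, 3]], dtype=np.int64),
        num_nodes=num_nodes,
    )


def check_graph(graph):
    """Assert basic shape invariants instead of printing the graph."""
    # One feature row per node.
    assert graph.x.shape[0] == graph.num_nodes
    # Edges are stored as a 2 x num_edges index array (assumed layout).
    assert graph.edge_index.shape[0] == 2
    # Every edge endpoint must refer to an existing node.
    assert graph.edge_index.max() < graph.num_nodes


def test_ogbnodedataset():
    # In the real test this line would load the dataset instead,
    # e.g. graph = OgbNodeDataset('ogbn-arxiv')[0]
    graph = make_toy_graph()
    check_graph(graph)
```

A uniquely named function with no trailing `test()` call lets a test runner such as pytest collect it and report it separately from the tests in other files.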
| return loss | ||
|
|
||
|
|
||
| class MultiHead(MessagePassing): |
Move this class into gammagl.layers.conv.unimp_conv.py. If it stays in this file, users will not be able to import and use it.
| return x | ||
|
|
||
|
|
||
| class Unimp(tlx.nn.Module): |
Move this class into gammagl.models.unimp.py. If it stays in this file, users will not be able to import and use it.
| super(Unimp, self).__init__() | ||
|
|
||
| out_layer1=int(dataset.num_node_features/2) | ||
| self.layer1=MultiHead(dataset.num_node_features+1, out_layer1, 4,dataset[0].num_nodes) |
Users should be able to choose n_heads themselves. Please expose it as a parameter instead of hard-coding a fixed value.
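One way the suggestion could be realized, shown as a framework-free sketch. The class `UnimpSketch` and its `layer_specs` attribute are hypothetical and only record the arguments that would be passed to each `MultiHead` layer; the layer sizes mirror the quoted diff (`num_node_features + 1`, then half the feature width, then `num_classes`), and the default of 4 keeps the current behavior.

```python
class UnimpSketch:
    """Hypothetical, framework-free outline of a configurable constructor."""

    def __init__(self, num_node_features, num_classes, n_heads=4):
        if n_heads < 1:
            raise ValueError("n_heads must be a positive integer")
        # Same hidden width as in the quoted diff: half the input features.
        hidden = int(num_node_features / 2)
        self.n_heads = n_heads
        # (in_features, out_features, n_heads) for each attention layer.
        self.layer_specs = [
            (num_node_features + 1, hidden, n_heads),
            (hidden, num_classes, n_heads),
        ]


# The caller picks the number of heads; omitting it falls back to 4.
model = UnimpSketch(num_node_features=128, num_classes=40, n_heads=2)
print(model.layer_specs)
```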
| self.norm1=nn.LayerNorm(out_layer1) | ||
| self.relu1=nn.ReLU() | ||
|
|
||
| self.layer2=MultiHead(out_layer1, dataset.num_classes, 4,dataset[0].num_nodes) |
| import tlx.nn as nn | ||
| from gammagl.layers import MultiHead | ||
|
|
||
| class Unimp(tlx.nn.Module): |
Please add an rst docstring here; refer to this link: https://docs.qq.com/pdf/DUXRTTU9tUnB1WnFB.
| return loss | ||
|
|
||
|
|
||
| def forward(self, x, edge_index): |
Why are there two forward functions?
|
|
||
| alpha = self.dropout(segment_softmax(weight, node_dst, self.num_nodes)) | ||
| x = tlx.gather(x, node_src) * tlx.expand_dims(alpha, -1) | ||
| return x |
Does this function implement Unimp? If so, what is the difference between this model and GAT?
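For readers following the quoted lines, here is a NumPy sketch of the computation they appear to perform: attention weights are normalized with a softmax over all edges that share a destination node, then used to scale the gathered source features. This is an illustration under assumptions, not the PR's implementation: it uses one scalar weight per edge (a single head), omits dropout, re-implements `segment_softmax` locally rather than using the library function, and adds a sum aggregation step that is not shown in the diff.

```python
import numpy as np


def segment_softmax(weight, segment_ids, num_segments):
    """Softmax over the entries of `weight` that share a segment id."""
    # Subtract the per-segment maximum for numerical stability.
    seg_max = np.full(num_segments, -np.inf)
    np.maximum.at(seg_max, segment_ids, weight)
    exp = np.exp(weight - seg_max[segment_ids])
    # Sum the exponentials within each segment, then normalize.
    seg_sum = np.zeros(num_segments)
    np.add.at(seg_sum, segment_ids, exp)
    return exp / seg_sum[segment_ids]


# Toy graph: edges 0->2, 1->2, 2->0 with one raw attention score per edge.
x = np.array([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])
node_src = np.array([0, 1, 2])
node_dst = np.array([2, 2, 0])
weight = np.array([0.0, 0.0, 5.0])

# Normalize scores over the incoming edges of each destination node.
alpha = segment_softmax(weight, node_dst, num_segments=3)
# Scale each gathered source feature by its attention coefficient.
messages = x[node_src] * alpha[:, None]
# Assumed aggregation: sum the messages arriving at each destination.
out = np.zeros_like(x)
np.add.at(out, node_dst, messages)
```

In this toy case the two edges into node 2 have equal scores, so each gets a coefficient of 0.5, while the single edge into node 0 gets 1.0.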