Skip to content

[Model] Implement model Unimp#83

Open
yangcd-bupt wants to merge 25 commits into
BUPT-GAMMA:mainfrom
yangcd-bupt:main
Open

[Model] Implement model Unimp#83
yangcd-bupt wants to merge 25 commits into
BUPT-GAMMA:mainfrom
yangcd-bupt:main

Conversation

@yangcd-bupt

Copy link
Copy Markdown

support ogbn dataset

Description

Checklist

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [$CATEGORY] (such as [NN], [Model], [Doc], [Feature]])
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • Code is well-documented
  • To the best of my knowledge, examples are either not affected by this change,
    or have been fixed to be compatible with this change
  • Related issue is referred in this PR

Changes

@yangcd-bupt

Copy link
Copy Markdown
Author

get ogb node dataset via OgbNodeDataset class

@Zhanghyi Zhanghyi left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

增加一些测试代码,对于graph,link,node 各自选择一个最小的数据集写一下单元测试的代码 放在 https://github.com/BUPT-GAMMA/GammaGL/tree/main/tests/datasets 下面

Comment thread gammagl/datasets/ogb_graph.py Outdated
import os.path as osp
import numpy as np
from gammagl.data import InMemoryDataset
from gammgl.utils.ogb_url import decide_download, download_url, extract_zip

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gammgl -> gammagl

Comment thread gammagl/utils/ogb_url.py Outdated
@@ -0,0 +1,91 @@
import urllib.request as ur

@Zhanghyi Zhanghyi Sep 15, 2022

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment thread gammagl/datasets/ogb_graph.py Outdated

# check if previously-downloaded folder exists.
# If so, use that one.
if osp.exists(osp.join(root, self.dir_name + '_pyg')):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pyg相关的代码改成gammagl

Comment thread tests/datasets/test_ogb_graph.py Outdated
print(data)
print(data[0])

test() No newline at end of file

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replace the function name "test" with "test_ogbgraphdataset", which makes it easier to distinguish the functions from other files.

Comment thread tests/datasets/test_ogb_link.py Outdated
data=OgbLinkDataset('ogbl-ppa')
print(data[0])

test() No newline at end of file

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto.

Comment thread tests/datasets/test_ogb_node.py Outdated
data=OgbNodeDataset('ogbn-arxiv')
print(data[0])

test() No newline at end of file

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto.

Comment thread gammagl/datasets/ogb_node.py Outdated
from gammagl.data import InMemoryDataset
from gammagl.data.download import download_url
from gammagl.data.extract import extract_zip
from gammagl.io.read_ogb import read_node_label_hetero, read_graph, read_heterograph, read_nodesplitidx_split_hetero

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

中文逗号

Comment thread tests/datasets/test_ogbgraphdataset.py Outdated
print(data)
print(data[0])

test() No newline at end of file

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Replace the function name "test" with "test_ogbgraphdataset", which makes it easier to distinguish the functions from other files.
  2. Using "assert" to check the correctness of some variable. e.g. the "feature.shape[0]" and "num_nodes" should be equal.

Comment thread tests/datasets/test_ogblinkdataset.py Outdated
data=OgbLinkDataset('ogbl-ppa')
print(data[0])

test() No newline at end of file

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto.

Comment thread tests/datasets/test_ogbnodedataset.py Outdated
data=OgbNodeDataset('ogbn-arxiv')
print(data[0])

test() No newline at end of file

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto.

Comment thread examples/unimp/unimp_trainer.py Outdated
return loss


class MultiHead(MessagePassing):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put this section into gammagl.layers.conv.unimp_conv.py. If you put this into this file, users will not be able to use this function.

Comment thread examples/unimp/unimp_trainer.py Outdated
return x


class Unimp(tlx.nn.Module):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put this section into gammagl.models.unimp.py. If you put this into this file, users will not be able to use this function.

Comment thread gammagl/models/unimp.py
super(Unimp, self).__init__()

out_layer1=int(dataset.num_node_features/2)
self.layer1=MultiHead(dataset.num_node_features+1, out_layer1, 4,dataset[0].num_nodes)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think users can choose their n_heads, so change this to let users choose their own n_heads instead of a fixed one.

Comment thread gammagl/models/unimp.py
self.norm1=nn.LayerNorm(out_layer1)
self.relu1=nn.ReLU()

self.layer2=MultiHead(out_layer1, dataset.num_classes, 4,dataset[0].num_nodes)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto.

Comment thread gammagl/models/unimp.py
import tlx.nn as nn
from gammagl.layers import MultiHead

class Unimp(tlx.nn.Module):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add rst doc here refer to this link: https://docs.qq.com/pdf/DUXRTTU9tUnB1WnFB.

@dddg617 dddg617 changed the title Create ogb_node.py [Model] Implement model Unimp Oct 28, 2022
Comment thread examples/unimp/unimp_trainer.py Outdated
return loss


def forward(self, x, edge_index):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why there are two forward functions?

Comment thread gammagl/layers/conv/multi_head.py Outdated

alpha = self.dropout(segment_softmax(weight, node_dst, self.num_nodes))
x = tlx.gather(x, node_src) * tlx.expand_dims(alpha, -1)
return x

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this function means Unimp? So what is the difference between this model and model GAT?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants