Graphormer-style attention bias option in GPS
rampasek committed Feb 8, 2023
Parent: bfb6651 · Commit: 78145e2
Showing 5 changed files with 67 additions and 20 deletions.
graphgps/encoder/composed_encoders.py (19 changes: 17 additions & 2 deletions)
@@ -13,6 +13,7 @@
from graphgps.encoder.type_dict_encoder import TypeDictNodeEncoder
from graphgps.encoder.linear_node_encoder import LinearNodeEncoder
from graphgps.encoder.equivstable_laplace_pos_encoder import EquivStableLapPENodeEncoder
from graphgps.encoder.graphormer_encoder import GraphormerEncoder


def concat_node_encoders(encoder_classes, pe_enc_names):
@@ -47,7 +48,7 @@ def __init__(self, dim_emb):
else:
# PE dims can only be gathered once the cfg is loaded.
enc2_dim_pe = getattr(cfg, f"posenc_{self.enc2_name}").dim_pe

self.encoder1 = self.enc1_cls(dim_emb - enc2_dim_pe)
self.encoder2 = self.enc2_cls(dim_emb, expand_x=False)

@@ -112,7 +113,8 @@ def forward(self, batch):
'HKdiagSE': HKdiagSENodeEncoder,
'ElstaticSE': ElstaticSENodeEncoder,
'SignNet': SignNetNodeEncoder,
'EquivStableLapPE': EquivStableLapPENodeEncoder}
'EquivStableLapPE': EquivStableLapPENodeEncoder,
'GraphormerBias': GraphormerEncoder}

# Concat dataset-specific and PE encoders.
for ds_enc_name, ds_enc_cls in ds_encs.items():
@@ -138,3 +140,16 @@ def forward(self, batch):
concat_node_encoders([ds_enc_cls, SignNetNodeEncoder, RWSENodeEncoder],
['SignNet', 'RWSE'])
)

# Combine GraphormerBias with LapPE or RWSE positional encodings.
for ds_enc_name, ds_enc_cls in ds_encs.items():
register_node_encoder(
f"{ds_enc_name}+GraphormerBias+LapPE",
concat_node_encoders([ds_enc_cls, GraphormerEncoder, LapPENodeEncoder],
['GraphormerBias', 'LapPE'])
)
register_node_encoder(
f"{ds_enc_name}+GraphormerBias+RWSE",
concat_node_encoders([ds_enc_cls, GraphormerEncoder, RWSENodeEncoder],
['GraphormerBias', 'RWSE'])
)
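
For context, a minimal, self-contained sketch (not part of this commit) of the pattern these registrations rely on: concat_node_encoders chains a dataset-specific node encoder with one or two positional-encoding encoders, the dataset encoder filling dim_emb - dim_pe channels and each PE encoder concatenating its own channels. The Toy* classes and all dimensions below are illustrative assumptions, not GraphGPS classes.

import torch
import torch.nn as nn

class ToyFeatureEncoder(nn.Module):
    """Stands in for a dataset-specific node encoder (e.g. an atom encoder)."""
    def __init__(self, dim_in, dim_emb):
        super().__init__()
        self.lin = nn.Linear(dim_in, dim_emb)

    def forward(self, x):
        return self.lin(x)

class ToyPEEncoder(nn.Module):
    """Stands in for a PE encoder that concatenates dim_pe extra channels."""
    def __init__(self, dim_pe_raw, dim_pe):
        super().__init__()
        self.lin = nn.Linear(dim_pe_raw, dim_pe)

    def forward(self, h, pe_raw):
        return torch.cat([h, self.lin(pe_raw)], dim=-1)

class ToyConcatEncoder(nn.Module):
    def __init__(self, dim_in, dim_emb, dim_pe_raw, dim_pe):
        super().__init__()
        # Leave room for the PE channels, mirroring
        # `self.encoder1 = self.enc1_cls(dim_emb - enc2_dim_pe)` above.
        self.feat_enc = ToyFeatureEncoder(dim_in, dim_emb - dim_pe)
        self.pe_enc = ToyPEEncoder(dim_pe_raw, dim_pe)

    def forward(self, x, pe_raw):
        return self.pe_enc(self.feat_enc(x), pe_raw)

enc = ToyConcatEncoder(dim_in=9, dim_emb=64, dim_pe_raw=16, dim_pe=8)
out = enc(torch.randn(10, 9), torch.randn(10, 16))
assert out.shape == (10, 64)

In GraphGPS the composed variants registered above are then selected by name, presumably via cfg.dataset.node_encoder_name, e.g. something like 'Atom+GraphormerBias+RWSE'.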
graphgps/encoder/kernel_pos_encoder.py (6 changes: 3 additions & 3 deletions)
@@ -42,13 +42,13 @@ def __init__(self, dim_emb, expand_x=True):
norm_type = pecfg.raw_norm_type.lower() # Raw PE normalization layer type
self.pass_as_var = pecfg.pass_as_var # Pass PE also as a separate variable

if dim_emb - dim_pe < 1:
if dim_emb - dim_pe < 0: # formerly 1, but you could have zero feature size
raise ValueError(f"PE dim size {dim_pe} is too large for "
f"desired embedding size of {dim_emb}.")

if expand_x:
if expand_x and dim_emb - dim_pe > 0:
self.linear_x = nn.Linear(dim_in, dim_emb - dim_pe)
self.expand_x = expand_x
self.expand_x = expand_x and dim_emb - dim_pe > 0

if norm_type == 'batchnorm':
self.raw_norm = nn.BatchNorm1d(num_rw_steps)
graphgps/encoder/laplace_pos_encoder.py (6 changes: 3 additions & 3 deletions)
@@ -35,13 +35,13 @@ def __init__(self, dim_emb, expand_x=True):
norm_type = pecfg.raw_norm_type.lower() # Raw PE normalization layer type
self.pass_as_var = pecfg.pass_as_var # Pass PE also as a separate variable

if dim_emb - dim_pe < 1:
if dim_emb - dim_pe < 0: # formerly 1, but you could have zero feature size
raise ValueError(f"LapPE size {dim_pe} is too large for "
f"desired embedding size of {dim_emb}.")

if expand_x:
if expand_x and dim_emb - dim_pe > 0:
self.linear_x = nn.Linear(dim_in, dim_emb - dim_pe)
self.expand_x = expand_x
self.expand_x = expand_x and dim_emb - dim_pe > 0

# Initial projection of eigenvalue and the node's eigenvector value
self.linear_A = nn.Linear(2, dim_pe)
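
Both encoder changes above relax the same guard so that dim_emb - dim_pe may be exactly zero, i.e. the node embedding may consist of positional encodings only, in which case no feature-expansion layer is created. A minimal sketch of that guard as a standalone helper (illustrative, not GraphGPS code; the argument names mirror the snippets above):

import torch.nn as nn

def make_feature_expansion(dim_in, dim_emb, dim_pe, expand_x=True):
    # A PE larger than the whole embedding is still an error.
    if dim_emb - dim_pe < 0:                       # formerly `< 1`
        raise ValueError(f"PE dim size {dim_pe} is too large for "
                         f"desired embedding size of {dim_emb}.")
    # Only build the expansion layer when raw features get any channels.
    if expand_x and dim_emb - dim_pe > 0:
        return nn.Linear(dim_in, dim_emb - dim_pe)
    return None                                    # PE-only embedding

# dim_emb > dim_pe: raw features are expanded into the remaining channels.
assert isinstance(make_feature_expansion(5, dim_emb=64, dim_pe=16), nn.Linear)
# dim_emb == dim_pe: allowed after this change, no expansion layer is built.
assert make_feature_expansion(5, dim_emb=16, dim_pe=16) is None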
graphgps/layer/gps_layer.py (41 changes: 33 additions & 8 deletions)
@@ -33,15 +33,30 @@ def __init__(self, dim_h,
self.activation = register.act_dict[act]

self.log_attn_weights = log_attn_weights
if log_attn_weights and global_model_type != 'Transformer':
if log_attn_weights and global_model_type not in ['Transformer',
'BiasedTransformer']:
raise NotImplementedError(
"Logging of attention weights is only supported for "
"Transformer global attention model."
f"Logging of attention weights is not supported "
f"for '{global_model_type}' global attention model."
)

# Local message-passing model.
self.local_gnn_with_edge_attr = True
if local_gnn_type == 'None':
self.local_model = None

# MPNNs without edge attributes support.
elif local_gnn_type == "GCN":
self.local_gnn_with_edge_attr = False
self.local_model = pygnn.GCNConv(dim_h, dim_h)
elif local_gnn_type == 'GIN':
self.local_gnn_with_edge_attr = False
gin_nn = nn.Sequential(Linear_pyg(dim_h, dim_h),
self.activation(),
Linear_pyg(dim_h, dim_h))
self.local_model = pygnn.GINConv(gin_nn)

# MPNNs supporting also edge attributes.
elif local_gnn_type == 'GENConv':
self.local_model = pygnn.GENConv(dim_h, dim_h)
elif local_gnn_type == 'GINE':
@@ -86,7 +101,7 @@ def __init__(self, dim_h,
# Global attention transformer-style model.
if global_model_type == 'None':
self.self_attn = None
elif global_model_type == 'Transformer':
elif global_model_type in ['Transformer', 'BiasedTransformer']:
self.self_attn = torch.nn.MultiheadAttention(
dim_h, num_heads, dropout=self.attn_dropout, batch_first=True)
# self.global_model = torch.nn.TransformerEncoderLayer(
@@ -158,11 +173,18 @@ def forward(self, batch):
h_local = local_out.x
batch.edge_attr = local_out.edge_attr
else:
if self.equivstable_pe:
h_local = self.local_model(h, batch.edge_index, batch.edge_attr,
batch.pe_EquivStableLapPE)
if self.local_gnn_with_edge_attr:
if self.equivstable_pe:
h_local = self.local_model(h,
batch.edge_index,
batch.edge_attr,
batch.pe_EquivStableLapPE)
else:
h_local = self.local_model(h,
batch.edge_index,
batch.edge_attr)
else:
h_local = self.local_model(h, batch.edge_index, batch.edge_attr)
h_local = self.local_model(h, batch.edge_index)
h_local = self.dropout_local(h_local)
h_local = h_in1 + h_local # Residual connection.

@@ -177,6 +199,9 @@ def forward(self, batch):
h_dense, mask = to_dense_batch(h, batch.batch)
if self.global_model_type == 'Transformer':
h_attn = self._sa_block(h_dense, None, ~mask)[mask]
elif self.global_model_type == 'BiasedTransformer':
# Use Graphormer-like conditioning, requires `batch.attn_bias`.
h_attn = self._sa_block(h_dense, batch.attn_bias, ~mask)[mask]
elif self.global_model_type == 'Performer':
h_attn = self.self_attn(h_dense, mask=mask)[mask]
elif self.global_model_type == 'BigBird':
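
The new 'BiasedTransformer' branch feeds batch.attn_bias into the self-attention block, i.e. `self._sa_block(h_dense, batch.attn_bias, ~mask)`. A minimal sketch (not the GraphGPS implementation) of the mechanism this relies on: torch.nn.MultiheadAttention adds a float attn_mask to the attention logits, so a Graphormer-style per-node-pair bias, e.g. indexed by shortest-path distance, can condition the attention. The shortest-path tensor and bias table below are illustrative assumptions; in GraphGPS the bias is presumably produced by the GraphormerEncoder registered in composed_encoders.py.

import torch
import torch.nn as nn

N, L, H, heads = 2, 6, 32, 4                 # batch, max nodes, hidden, heads
mha = nn.MultiheadAttention(H, heads, batch_first=True)

h_dense = torch.randn(N, L, H)               # padded node features (to_dense_batch)
mask = torch.ones(N, L, dtype=torch.bool)    # True = real node, False = padding
mask[1, 4:] = False

# Toy Graphormer-style bias: one learnable scalar per head per shortest-path
# distance, broadcast to every (query, key) node pair.
spd = torch.randint(0, 5, (N, L, L))                    # pretend SP distances
dist_bias = nn.Embedding(5, heads)
attn_bias = dist_bias(spd).permute(0, 3, 1, 2)          # (N, heads, L, L)
attn_bias = attn_bias.reshape(N * heads, L, L)          # shape MultiheadAttention expects

out, _ = mha(h_dense, h_dense, h_dense,
             attn_mask=attn_bias,            # float mask: added to attention scores
             key_padding_mask=~mask)         # padding nodes are ignored as keys
h_attn = out[mask]                           # back to a flat node tensor
assert h_attn.shape == (int(mask.sum()), H)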
graphgps/network/gps_model.py (15 changes: 11 additions & 4 deletions)
@@ -28,7 +28,7 @@ def __init__(self, dim_in):
self.node_encoder_bn = BatchNorm1dNode(
new_layer_config(cfg.gnn.dim_inner, -1, -1, has_act=False,
has_bias=False, cfg=cfg))
# Update dim_in to reflect the new dimension fo the node features
# Update dim_in to reflect the new dimension of the node features
self.dim_in = cfg.gnn.dim_inner
if cfg.dataset.edge_encoder:
# Hard-limit max edge dim for PNA.
@@ -53,7 +53,10 @@ def forward(self, batch):

@register_network('GPSModel')
class GPSModel(torch.nn.Module):
"""Multi-scale graph x-former.
"""General-Powerful-Scalable graph transformer.
https://arxiv.org/abs/2205.12454
Rampasek, L., Galkin, M., Dwivedi, V. P., Luu, A. T., Wolf, G., & Beaini, D.
Recipe for a general, powerful, scalable graph transformer. (NeurIPS 2022)
"""

def __init__(self, dim_in, dim_out):
@@ -66,8 +69,12 @@ def __init__(self, dim_in, dim_out):
dim_in, cfg.gnn.dim_inner, cfg.gnn.layers_pre_mp)
dim_in = cfg.gnn.dim_inner

assert cfg.gt.dim_hidden == cfg.gnn.dim_inner == dim_in, \
"The inner and hidden dims must match."
if not cfg.gt.dim_hidden == cfg.gnn.dim_inner == dim_in:
raise ValueError(
f"The inner and hidden dims must match: "
f"embed_dim={cfg.gt.dim_hidden} dim_inner={cfg.gnn.dim_inner} "
f"dim_in={dim_in}"
)

try:
local_gnn_type, global_model_type = cfg.gt.layer_type.split('+')
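
Two conventions are visible in this last hunk: cfg.gt.layer_type names the local MPNN and the global attention model as '<local>+<global>', and the transformer hidden dim must equal the GNN inner dim. A minimal sketch of both checks as standalone helpers (the layer-type string and dims below are illustrative; this is not the GPSModel code itself):

def parse_gps_layer_type(layer_type: str):
    # "<local MPNN>+<global attention>", e.g. "GINE+BiasedTransformer".
    try:
        local_gnn_type, global_model_type = layer_type.split('+')
    except ValueError as e:
        raise ValueError(f"Unexpected layer type: {layer_type}") from e
    return local_gnn_type, global_model_type

def check_dims(gt_dim_hidden: int, gnn_dim_inner: int, dim_in: int):
    # Mirrors the explicit ValueError introduced above.
    if not gt_dim_hidden == gnn_dim_inner == dim_in:
        raise ValueError(f"The inner and hidden dims must match: "
                         f"embed_dim={gt_dim_hidden} "
                         f"dim_inner={gnn_dim_inner} dim_in={dim_in}")

local_gnn, global_attn = parse_gps_layer_type("GINE+BiasedTransformer")
check_dims(64, 64, 64)   # passes; mismatched values would raise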
