add pseudo-permute with noise
ChloeXWang committed Jan 11, 2024
1 parent 27eb9cc commit 5b1b14b
Showing 2 changed files with 24 additions and 2 deletions.
2 changes: 1 addition & 1 deletion configs/GPS/peptides-func-GPS.yaml
@@ -36,7 +36,7 @@ model:
   loss_fun: cross_entropy
   graph_pooling: mean
 gt:
-  layer_type: CustomGatedGCN+Mamba_Cluster #Mamba_Hybrid_Degree
+  layer_type: CustomGatedGCN+Mamba_Hybrid_Degree_Noise
   n_heads: 4
   dim_hidden: 96 # `gt.dim_hidden` must match `gnn.dim_inner`
   dropout: 0.0
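
For orientation, here is a small sketch (not part of this diff, and an assumption about the surrounding code) of how a GraphGPS-style `layer_type` string of the form "<local MPNN>+<global model>" is typically split; the actual parsing lives in the model builder, which this commit does not touch.

    # Hedged sketch: splitting a "local+global" layer_type string (assumed GraphGPS convention).
    layer_type = "CustomGatedGCN+Mamba_Hybrid_Degree_Noise"
    local_gnn_type, global_model_type = layer_type.split('+')
    # local_gnn_type    -> "CustomGatedGCN"            (message-passing branch)
    # global_model_type -> "Mamba_Hybrid_Degree_Noise" (selects the new branch added in gps_layer.py below)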
24 changes: 23 additions & 1 deletion graphgps/layer/gps_layer.py
@@ -340,7 +340,7 @@ def forward(self, batch):
                     mamba_arr.append(h_attn)
                 h_attn = sum(mamba_arr) / 5

-        elif 'Mamba_Hybrid_Degree' in self.global_model_type:
+        elif 'Mamba_Hybrid_Degree' == self.global_model_type:
             if batch.split == 'train':
                 h_ind_perm = permute_within_batch(batch.batch)
                 #h_ind_perm = permute_nodes_within_identity(batch.batch)
@@ -380,6 +380,28 @@ def forward(self, batch):
                     #h_attn = self.self_attn(h_dense)[mask][h_ind_perm_reverse]
                     mamba_arr.append(h_attn)
                 h_attn = sum(mamba_arr) / 5

+        elif 'Mamba_Hybrid_Degree_Noise' == self.global_model_type:
+            if batch.split == 'train':
+                deg = degree(batch.edge_index[0], batch.x.shape[0]).to(torch.float)
+                #deg_noise = torch.std(deg)*torch.randn(deg.shape).to(deg.device)
+                deg_noise = torch.randn(deg.shape).to(deg.device)
+                h_ind_perm = lexsort([deg+deg_noise, batch.batch])
+                h_dense, mask = to_dense_batch(h[h_ind_perm], batch.batch[h_ind_perm])
+                h_ind_perm_reverse = torch.argsort(h_ind_perm)
+                h_attn = self.self_attn(h_dense)[mask][h_ind_perm_reverse]
+            else:
+                mamba_arr = []
+                for i in range(5):
+                    deg = degree(batch.edge_index[0], batch.x.shape[0]).to(torch.float)
+                    #deg_noise = torch.std(deg)*torch.randn(deg.shape).to(deg.device)
+                    deg_noise = torch.randn(deg.shape).to(deg.device)
+                    h_ind_perm = lexsort([deg+deg_noise, batch.batch])
+                    h_dense, mask = to_dense_batch(h[h_ind_perm], batch.batch[h_ind_perm])
+                    h_ind_perm_reverse = torch.argsort(h_ind_perm)
+                    h_attn = self.self_attn(h_dense)[mask][h_ind_perm_reverse]
+                    mamba_arr.append(h_attn)
+                h_attn = sum(mamba_arr) / 5
+
         elif self.global_model_type == 'Mamba_Eigen':
             deg = degree(batch.edge_index[0], batch.x.shape[0]).to(torch.long)
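
For readers without the surrounding file, the following is a self-contained sketch of the "pseudo-permute with noise" idea added above. It is an illustration under stated assumptions, not the repository's exact code: `lexsort` is assumed to follow numpy.lexsort semantics (the last key is the primary sort key), matching how it is called in gps_layer.py, and `noisy_degree_order` is a hypothetical helper name.

    import torch
    from torch_geometric.utils import degree, to_dense_batch

    def lexsort(keys, dim=-1):
        # numpy.lexsort-style composite sort built from stable argsorts; the LAST key dominates.
        idx = keys[0].argsort(dim=dim, stable=True)
        for k in keys[1:]:
            idx = idx.gather(dim, k.gather(dim, idx).argsort(dim=dim, stable=True))
        return idx

    def noisy_degree_order(h, edge_index, batch):
        # Order nodes within each graph by (degree + Gaussian noise), keeping graphs contiguous.
        deg = degree(edge_index[0], h.size(0)).to(torch.float)
        deg_noise = torch.randn_like(deg)           # fresh noise => a new pseudo-permutation per call
        perm = lexsort([deg + deg_noise, batch])    # primary key: graph id, secondary: noisy degree
        h_dense, mask = to_dense_batch(h[perm], batch[perm])
        perm_reverse = torch.argsort(perm)          # maps the sequence model's output back to node order
        return h_dense, mask, perm_reverse

At training time the diff draws one such ordering per forward pass; at evaluation time it averages the attention output over five independent noisy orderings (`sum(mamba_arr) / 5`) to reduce the variance introduced by the noise.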
