add 5-fold permutation at inference time
ChloeXWang committed Dec 23, 2023
1 parent 1add3b2 commit 84c0486
Showing 1 changed file with 68 additions and 0 deletions.

graphgps/layer/gps_layer.py (+68 −0)
@@ -20,6 +20,33 @@
import torch
from torch import Tensor

def sort_rand_gpu(pop_size, num_samples, neighbours):
    # Sample num_samples entries of neighbours uniformly at random without
    # replacement, entirely on the GPU: argsorting uniform noise yields a
    # random permutation, of which we keep the first num_samples positions.
    idx_select = torch.argsort(torch.rand(pop_size, device=neighbours.device))[:num_samples]
    return neighbours[idx_select]
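
For illustration only (this snippet is not part of the commit), the argsort-of-noise trick above amounts to uniform sampling without replacement:

# Hypothetical usage: pick 2 of 5 neighbour ids without replacement.
neighbours = torch.tensor([7, 3, 9, 1, 4])
sampled = sort_rand_gpu(len(neighbours), 2, neighbours)
# e.g. tensor([9, 3]); every 2-element subset is equally likely.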

def augment_seq(edge_index, batch, num_k=-1):
    # For every node, emit (a sample of) its out-neighbours followed by the
    # node itself; the returned boolean mask marks the positions of the
    # original nodes within the augmented index sequence.
    unique_batches = torch.unique(batch)
    permuted_indices = []
    mask = []

    for batch_index in unique_batches:
        # Indices of nodes in the current graph (always 1-D, even when the
        # graph has a single node)
        indices_in_batch = (batch == batch_index).nonzero(as_tuple=True)[0]
        for k in indices_in_batch:
            neighbours = edge_index[1][edge_index[0] == k]
            if num_k > 0 and len(neighbours) > num_k:
                # Keep at most num_k neighbours, sampled uniformly at random
                neighbours = sort_rand_gpu(len(neighbours), num_k, neighbours)
            permuted_indices.append(neighbours)
            mask.append(torch.zeros(neighbours.shape, dtype=torch.bool, device=batch.device))
            permuted_indices.append(k.reshape(1))
            mask.append(torch.tensor([True], dtype=torch.bool, device=batch.device))
    permuted_indices = torch.cat(permuted_indices)
    mask = torch.cat(mask)
    return permuted_indices, mask  # both already live on batch.device
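
A minimal toy example of what augment_seq produces (the edge list below is made up for illustration and is not part of the commit):

# 4-node graph, single batch; directed edges src -> dst.
edge_index = torch.tensor([[0, 0, 1, 2, 3],
                           [1, 2, 0, 0, 0]])
batch = torch.zeros(4, dtype=torch.long)

idx, mask = augment_seq(edge_index, batch, num_k=3)
# idx  -> tensor([1, 2, 0, 0, 1, 0, 2, 0, 3]): each node is preceded by its
#         (sampled) neighbours.
# mask -> True exactly at the positions of the original nodes 0, 1, 2, 3.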

def lexsort(
    keys: List[Tensor],
    dim: int = -1,
@@ -251,6 +278,47 @@ def forward(self, batch):
                h_dense, mask = to_dense_batch(h[h_ind_perm], batch.batch[h_ind_perm])
                h_ind_perm_reverse = torch.argsort(h_ind_perm)
                h_attn = self.self_attn(h_dense)[mask][h_ind_perm_reverse]
            elif self.global_model_type == 'Mamba_Hybrid':
                if batch.split == 'train':
                    # Training: a single random within-batch permutation
                    h_ind_perm = permute_within_batch(batch.batch)
                    h_dense, mask = to_dense_batch(h[h_ind_perm], batch.batch[h_ind_perm])
                    h_ind_perm_reverse = torch.argsort(h_ind_perm)
                    h_attn = self.self_attn(h_dense)[mask][h_ind_perm_reverse]
                else:
                    # Inference: average the output over 5 random permutations
                    # to reduce sensitivity to node order
                    mamba_arr = []
                    for _ in range(5):
                        h_ind_perm = permute_within_batch(batch.batch)
                        h_dense, mask = to_dense_batch(h[h_ind_perm], batch.batch[h_ind_perm])
                        h_ind_perm_reverse = torch.argsort(h_ind_perm)
                        h_attn = self.self_attn(h_dense)[mask][h_ind_perm_reverse]
                        mamba_arr.append(h_attn)
                    h_attn = sum(mamba_arr) / 5
            elif self.global_model_type == 'Mamba_Hybrid_Degree':
                if batch.split == 'train':
                    # Training: random permutation, then a stable sort by node
                    # degree within each graph
                    h_ind_perm = permute_within_batch(batch.batch)
                    deg = degree(batch.edge_index[0], batch.x.shape[0]).to(torch.long)
                    h_ind_perm_1 = lexsort([deg[h_ind_perm], batch.batch[h_ind_perm]])
                    h_ind_perm = h_ind_perm[h_ind_perm_1]
                    h_dense, mask = to_dense_batch(h[h_ind_perm], batch.batch[h_ind_perm])
                    h_ind_perm_reverse = torch.argsort(h_ind_perm)
                    h_attn = self.self_attn(h_dense)[mask][h_ind_perm_reverse]
                else:
                    # Inference: average over 5 runs; the random permutation
                    # before the degree sort reshuffles equal-degree nodes
                    mamba_arr = []
                    for _ in range(5):
                        h_ind_perm = permute_within_batch(batch.batch)
                        deg = degree(batch.edge_index[0], batch.x.shape[0]).to(torch.long)
                        h_ind_perm_1 = lexsort([deg[h_ind_perm], batch.batch[h_ind_perm]])
                        h_ind_perm = h_ind_perm[h_ind_perm_1]
                        h_dense, mask = to_dense_batch(h[h_ind_perm], batch.batch[h_ind_perm])
                        h_ind_perm_reverse = torch.argsort(h_ind_perm)
                        h_attn = self.self_attn(h_dense)[mask][h_ind_perm_reverse]
                        mamba_arr.append(h_attn)
                    h_attn = sum(mamba_arr) / 5
            elif self.global_model_type == 'Mamba_Augment':
                # Interleave up to 3 neighbours before each node, run the
                # sequence model, then keep only the original-node positions
                aug_idx, aug_mask = augment_seq(batch.edge_index, batch.batch, 3)
                h_dense, mask = to_dense_batch(h[aug_idx], batch.batch[aug_idx])
                aug_idx_reverse = torch.nonzero(aug_mask).squeeze(-1)
                h_attn = self.self_attn(h_dense)[mask][aug_idx_reverse]
            else:
                raise RuntimeError(f"Unexpected {self.global_model_type}")
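
The inference branches above share one pattern: run the order-sensitive sequence model several times under different random node orders and average the results. A standalone sketch of that pattern, not part of the repository (perm_fn, seq_model, and five_fold_inference are hypothetical names):

import torch
from torch_geometric.utils import to_dense_batch

def five_fold_inference(h, batch_vec, perm_fn, seq_model, n_folds=5):
    # Hypothetical helper: average seq_model's output over n_folds random
    # within-batch permutations. perm_fn must permute nodes only within
    # each graph's contiguous block so batch_vec[perm] stays sorted.
    outs = []
    for _ in range(n_folds):
        perm = perm_fn(batch_vec)                       # random order per graph
        h_dense, mask = to_dense_batch(h[perm], batch_vec[perm])
        perm_reverse = torch.argsort(perm)              # maps back to input order
        outs.append(seq_model(h_dense)[mask][perm_reverse])
    return sum(outs) / n_folds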
