[go: nahoru, domů]

Skip to content

Commit

Permalink
update cluster+bucket
Browse files Browse the repository at this point in the history
  • Loading branch information
ChloeXWang committed Jan 11, 2024
2 parents 3501f94 + 815cf5a commit 27eb9cc
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 31 deletions.
6 changes: 5 additions & 1 deletion configs/GPS/cocosuperpixels-GPS.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ model:
type: GPSModel
loss_fun: weighted_cross_entropy
gt:
<<<<<<< HEAD
layer_type: CustomGatedGCN+Mamba_Hybrid_Degree_Bucket #Transformer #Performer
=======
layer_type: CustomGatedGCN+Mamba_Bucket #Transformer #Performer
>>>>>>> 815cf5a26a1097601438fcae49ae7f288ca469fd
layers: 4
n_heads: 8
dim_hidden: 96 # `gt.dim_hidden` must match `gnn.dim_inner`
Expand All @@ -54,6 +58,7 @@ gnn:
dropout: 0.0
agg: mean
normalize_adj: False
<<<<<<< HEAD
# optim:
# clip_grad_norm: True
# optimizer: adamW
Expand All @@ -70,7 +75,6 @@ optim:
max_epoch: 300
scheduler: cosine_with_warmup
num_warmup_epochs: 10

#optim:
# optimizer: adamW
# weight_decay: 0.0
Expand Down
21 changes: 3 additions & 18 deletions configs/GPS/peptides-func-GPS.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,6 @@ posenc_LapPE:
dim_pe: 16
layers: 2
raw_norm_type: none
# posenc_RWSE:
# enable: True
# kernel:
# times_func: range(1,17)
# model: Linear
# dim_pe: 20
# raw_norm_type: BatchNorm
train:
mode: custom
batch_size: 128
Expand All @@ -43,7 +36,7 @@ model:
loss_fun: cross_entropy
graph_pooling: mean
gt:
layer_type: CustomGatedGCN+Mamba_Cluster_Bucket #Mamba_Hybrid_Degree
layer_type: CustomGatedGCN+Mamba_Cluster #Mamba_Hybrid_Degree
n_heads: 4
dim_hidden: 96 # `gt.dim_hidden` must match `gnn.dim_inner`
dropout: 0.0
Expand All @@ -58,19 +51,11 @@ gnn:
batchnorm: True
act: relu
dropout: 0.0
# optim:
# clip_grad_norm: True
# optimizer: adamW
# weight_decay: 0.0
# base_lr: 0.0003
# max_epoch: 200
# scheduler: cosine_with_warmup
# num_warmup_epochs: 10
optim:
clip_grad_norm: True
optimizer: adamW
weight_decay: 0.01
base_lr: 0.001
weight_decay: 0.0
base_lr: 0.0003
max_epoch: 200
scheduler: cosine_with_warmup
num_warmup_epochs: 10
Expand Down
12 changes: 2 additions & 10 deletions configs/GPS/peptides-struct-GPS.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,11 @@ gnn:
batchnorm: True
act: relu
dropout: 0.0
#optim:
# clip_grad_norm: True
# optimizer: adamW
# weight_decay: 0.0
# base_lr: 0.0003
# max_epoch: 200
# scheduler: cosine_with_warmup
# num_warmup_epochs: 10
optim:
clip_grad_norm: True
optimizer: adamW
weight_decay: 0.01
base_lr: 0.001
weight_decay: 0.0
base_lr: 0.0003
max_epoch: 200
scheduler: cosine_with_warmup
num_warmup_epochs: 10
Expand Down
1 change: 0 additions & 1 deletion configs/GPS/vocsuperpixels-GPS.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
out_dir: results
metric_best: f1
wandb:
use: True
project: PascalVOC-SP
entity: tf-map
dataset:
Expand Down
4 changes: 3 additions & 1 deletion graphgps/layer/gps_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ def forward(self, batch):
h_attn = self.self_attn(h_dense)[mask][h_ind_perm_reverse]
mamba_arr.append(h_attn)
h_attn = sum(mamba_arr) / 5

elif 'Mamba_Hybrid_Degree' in self.global_model_type:
if batch.split == 'train':
h_ind_perm = permute_within_batch(batch.batch)
Expand Down Expand Up @@ -379,6 +380,7 @@ def forward(self, batch):
#h_attn = self.self_attn(h_dense)[mask][h_ind_perm_reverse]
mamba_arr.append(h_attn)
h_attn = sum(mamba_arr) / 5

elif self.global_model_type == 'Mamba_Eigen':
deg = degree(batch.edge_index[0], batch.x.shape[0]).to(torch.long)
centrality = batch.EigCentrality
Expand Down Expand Up @@ -420,7 +422,7 @@ def forward(self, batch):
unique_cluster_n = len(torch.unique(batch.LouvainCluster))
permuted_louvain = torch.zeros(batch.LouvainCluster.shape).long().to(batch.LouvainCluster.device)
random_permute = torch.randperm(unique_cluster_n+1).long().to(batch.LouvainCluster.device)
for i in range(len(torch.unique(batch.LouvainCluster))):
for i in range(unique_cluster_n):
indices = torch.nonzero(batch.LouvainCluster == i).squeeze()
permuted_louvain[indices] = random_permute[i]
#h_ind_perm_1 = lexsort([deg[h_ind_perm], permuted_louvain[h_ind_perm], batch.batch[h_ind_perm]])
Expand Down

0 comments on commit 27eb9cc

Please sign in to comment.