SaGess: Sampling Graph Discrete Denoising Diffusion Model

Official PyTorch implementation of SaGess.

SaGess is a discrete denoising diffusion model that extends DiGress with a divide-and-conquer strategy: it trains on subgraph samples and reconstructs the overall graph from the generated subgraphs, allowing it to generate large synthetic networks.
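To make the divide-and-conquer idea concrete, here is a minimal, self-contained sketch (not the SaGess implementation, and with no diffusion step): subgraphs are sampled from a large graph, and the full edge set is rebuilt by taking the union of the sampled pieces. All function and variable names are illustrative.

```python
# Illustrative sketch of the divide-and-conquer idea only; the actual model
# trains a denoising diffusion model on the sampled subgraphs before
# reconstruction.
import random

def sample_subgraph(edges, size, rng):
    """Sample a connected node set via a simple random walk, then
    return the edges induced on those nodes (illustrative sampler)."""
    adj = {}
    for u, v in edges:
        adj.setdefault(u, set()).add(v)
        adj.setdefault(v, set()).add(u)
    node = rng.choice(sorted(adj))
    nodes = {node}
    while len(nodes) < size:
        node = rng.choice(sorted(adj[node]))  # step to a random neighbor
        nodes.add(node)
    return [(u, v) for u, v in edges if u in nodes and v in nodes]

def reconstruct(subgraph_edge_lists):
    """Rebuild a large graph as the union of subgraph edge sets,
    normalizing edge orientation so duplicates collapse."""
    merged = set()
    for edge_list in subgraph_edge_lists:
        for u, v in edge_list:
            merged.add((min(u, v), max(u, v)))
    return merged

rng = random.Random(0)
ring = [(i, (i + 1) % 10) for i in range(10)]  # toy 10-node ring graph
samples = [sample_subgraph(ring, 4, rng) for _ in range(20)]
full = reconstruct(samples)
```

With enough samples, the union recovers (a large part of) the original edge set; SaGess replaces the trivial "copy the sampled edges" step with subgraphs generated by the trained diffusion model.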

Setting up the environment

Create anaconda environment

chmod +x *.sh

./install_conda_env.sh

Activate the environment

conda activate sagess

Setting up Wandb logs

By default, wandb stores logs offline, so they need to be synced after training. Make sure to set the 'entity' parameter in the setup_wandb() function located in src/run_sagess.py so that logs can be synced to your account.

 'entity': 'wandb_username'

For online syncing, change the 'wandb' parameter in configs/general/general_default.yaml to 'online'.
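For reference, the relevant fragment of configs/general/general_default.yaml would look something like this (only the `wandb` key is confirmed by this README; the comment values are illustrative):

```yaml
# configs/general/general_default.yaml (fragment)
wandb: online   # default is offline; 'online' syncs runs as they happen
```

Offline runs can also be pushed to the server afterwards with the wandb CLI, e.g. `wandb sync <path-to-offline-run-folder>`.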

Running the code

The main script can be launched as such:

 python src/run_sagess.py dataset=Cora

Four datasets from torch_geometric are supported (Cora, Wiki, EmailEUCore, ego-facebook), plus one custom SBM dataset loaded from a .pkl file. All datasets are downloaded to, or should be placed in, the data folder.

Saved checkpoints, wandb log folder and other outputs can be found in the outputs folder.

Configs & parameters

Dataset specific configuration resides in configs/dataset/*.yaml files, including the number of subgraphs to train on, their size and sampling method.
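As an illustration of the parameters mentioned above, a dataset config might look like the following sketch (the key names are assumptions; check the actual configs/dataset/*.yaml files for the exact schema):

```yaml
# configs/dataset/cora.yaml (illustrative sketch, not the actual file)
name: Cora
num_subgraphs: 2000       # how many subgraph samples to train on
subgraph_size: 20         # nodes per sampled subgraph
sampling_method: random_walk
```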

Other default parameters for DiGress are found in configs/train/train_default.yaml, configs/model/discrete.yaml and configs/general/general_default.yaml.

Additional support for docker

To build and run the docker container, use the docker_build.sh and run_docker_container.sh scripts, respectively.
