Commit

fix typos and add missing utils directory
Minkyung Baek committed Jul 2, 2021
1 parent a98a404 commit 5510443
Showing 7 changed files with 194 additions and 2 deletions.
1 change: 0 additions & 1 deletion RoseTTAFold-linux.yml
@@ -12,7 +12,6 @@ dependencies:
   - biocore::blast-legacy=2.2.26
   - hhsuite
   - psipred=4.01
-  - parallel
   - _libgcc_mutex=0.1=main
   - _openmp_mutex=4.5=1_gnu
   - blas=1.0=mkl
Binary file removed example/pyrosetta/tmp.npz
1 change: 1 addition & 0 deletions folding-linux.yml
@@ -6,3 +6,4 @@ dependencies:
   - tensorflow-gpu=1.14
   - pandas
   - scikit-learn=0.24
+  - parallel
2 changes: 1 addition & 1 deletion network/predict_pyRosetta.py
@@ -104,7 +104,7 @@ def predict(self, a3m_fn, out_prefix, hhr_fn=None, atab_fn=None, window=150, shi
         prob_s = [np.zeros((L,L,NBIN[i]), dtype=np.float32) for i in range(4)]
         count = np.zeros((L,L), dtype=np.float32)
         #
-        grids = np.arange(0, L-wnindow+shift, shift)
+        grids = np.arange(0, L-window+shift, shift)
         ngrids = grids.shape[0]
         print("ngrid: ", ngrids)
         print("grids: ", grids)
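For context (not part of the commit): the corrected line computes the start offsets of overlapping crops used to tile a long sequence. A worked sketch with hypothetical sizes; window=150 appears in the hunk header above, while shift=75 is an assumed value:

import numpy as np

# Hypothetical sizes: a 400-residue protein, 150-residue windows, stride 75.
L, window, shift = 400, 150, 75
grids = np.arange(0, L - window + shift, shift)
print(grids)  # [  0  75 150 225 300]: window start offsets; adjacent crops overlap by window - shift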
64 changes: 64 additions & 0 deletions network/utils/utils_data.py
@@ -0,0 +1,64 @@
import warnings

import dgl
import torch


def to_np(x):
    return x.cpu().detach().numpy()


class PickleGraph:
    """Lightweight graph object for easy pickling. Does not support batched graphs."""

    def __init__(self, G=None, desired_keys=None):
        self.ndata = dict()
        self.edata = dict()

        if G is None:
            self.src = []
            self.dst = []
        else:
            if G.batch_size > 1:
                warnings.warn("Copying a batched graph to a PickleGraph is not supported. "
                              "All node and edge data will be copied, but batching information will be lost.")

            self.src, self.dst = (to_np(idx) for idx in G.all_edges())

            for k in G.ndata:
                if desired_keys is None or k in desired_keys:
                    self.ndata[k] = to_np(G.ndata[k])

            for k in G.edata:
                if desired_keys is None or k in desired_keys:
                    self.edata[k] = to_np(G.edata[k])

    def all_edges(self):
        return self.src, self.dst


def copy_dgl_graph(G):
    if G.batch_size == 1:
        src, dst = G.all_edges()
        G2 = dgl.DGLGraph((src, dst))
        for edge_key in list(G.edata.keys()):
            G2.edata[edge_key] = torch.clone(G.edata[edge_key])
        for node_key in list(G.ndata.keys()):
            G2.ndata[node_key] = torch.clone(G.ndata[node_key])
        return G2
    else:
        list_of_graphs = dgl.unbatch(G)
        list_of_copies = []

        for batch_G in list_of_graphs:
            list_of_copies.append(copy_dgl_graph(batch_G))

        return dgl.batch(list_of_copies)


def update_relative_positions(G, *, relative_position_key='d', absolute_position_key='x'):
    """For each directed edge in the graph, calculate the relative position of the destination node with respect
    to the source node. Write the relative positions to the graph as edge data."""
    src, dst = G.all_edges()
    absolute_positions = G.ndata[absolute_position_key]
    G.edata[relative_position_key] = absolute_positions[dst] - absolute_positions[src]
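A minimal usage sketch for the new helpers (not part of the commit). The tiny graph and random coordinates are hypothetical; 'x' and 'd' are simply the defaults of update_relative_positions, and the dgl.DGLGraph((src, dst)) construction mirrors the call copy_dgl_graph itself makes, so it assumes the same DGL version the repo targets:

import pickle

import dgl
import torch

from utils.utils_data import PickleGraph, copy_dgl_graph, update_relative_positions

G = dgl.DGLGraph(([0, 1, 2], [1, 2, 0]))      # directed 3-cycle
G.ndata['x'] = torch.randn(3, 3)              # absolute node positions
update_relative_positions(G)                  # writes G.edata['d'] = x[dst] - x[src]

G2 = copy_dgl_graph(G)                        # independent copy with cloned tensors
pg = PickleGraph(G, desired_keys={'x', 'd'})  # lightweight snapshot of the data
blob = pickle.dumps(pg)                       # the snapshot pickles cleanly
src, dst = pg.all_edges()                     # numpy arrays, same API shape as DGL's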
123 changes: 123 additions & 0 deletions network/utils/utils_logging.py
@@ -0,0 +1,123 @@
import os
import sys
import time
import datetime
import subprocess
import numpy as np
import torch

from utils.utils_data import to_np


_global_log = {}


def try_mkdir(path):
    if not os.path.exists(path):
        os.makedirs(path)


# @profile
def make_logdir(checkpoint_dir, run_name=None):
    if run_name is None:
        now = datetime.datetime.now().strftime("%Y_%m_%d_%H.%M.%S")
    else:
        assert type(run_name) == str
        now = run_name

    log_dir = os.path.join(checkpoint_dir, now)
    try_mkdir(log_dir)
    return log_dir


def count_parameters(model):
    """
    count number of trainable parameters in module
    :param model: nn.Module instance
    :return: integer
    """
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    n_params = sum([np.prod(p.size()) for p in model_parameters])
    return n_params


def write_info_file(model, FLAGS, UNPARSED_ARGV, wandb_log_dir=None):
    time_str = time.strftime("%m%d_%H%M%S")
    filename_log = "info_" + time_str + ".txt"
    filename_git_diff = "git_diff_" + time_str + ".txt"

    checkpoint_name = 'model'

    if wandb_log_dir:
        log_dir = wandb_log_dir
        os.mkdir(os.path.join(log_dir, 'checkpoints'))
        checkpoint_path = os.path.join(log_dir, 'checkpoints', checkpoint_name)
    elif FLAGS.restore:
        # set restore path
        assert FLAGS.run_name is not None
        log_dir = os.path.join(FLAGS.checkpoint_dir, FLAGS.run_name)
        checkpoint_path = os.path.join(log_dir, 'checkpoints', checkpoint_name)
    else:
        # makes logdir with time stamp
        log_dir = make_logdir(FLAGS.checkpoint_dir, FLAGS.run_name)
        os.mkdir(os.path.join(log_dir, 'checkpoints'))
        os.mkdir(os.path.join(log_dir, 'point_clouds'))
        # os.mkdir(os.path.join(log_dir, 'train_log'))
        # os.mkdir(os.path.join(log_dir, 'test_log'))
        checkpoint_path = os.path.join(log_dir, 'checkpoints', checkpoint_name)

    # writing arguments and git hash to info file
    file = open(os.path.join(log_dir, filename_log), "w")
    label = subprocess.check_output(["git", "describe", "--always"]).strip()
    file.write('latest git commit on this branch: ' + str(label) + '\n')
    file.write('\nFLAGS: \n')
    for key in sorted(vars(FLAGS)):
        file.write(key + ': ' + str(vars(FLAGS)[key]) + '\n')

    # count number of parameters
    if hasattr(model, 'parameters'):
        file.write('\nNumber of Model Parameters: ' + str(count_parameters(model)) + '\n')
    if hasattr(model, 'enc'):
        file.write('\nNumber of Encoder Parameters: ' + str(
            count_parameters(model.enc)) + '\n')
    if hasattr(model, 'dec'):
        file.write('\nNumber of Decoder Parameters: ' + str(
            count_parameters(model.dec)) + '\n')

    file.write('\nUNPARSED_ARGV:\n' + str(UNPARSED_ARGV))
    file.write('\n\nBASH COMMAND: \n')
    bash_command = 'python'
    for argument in sys.argv:
        bash_command += (' ' + argument)
    file.write(bash_command)
    file.close()

    # write 'git diff' output into extra file
    subprocess.call(["git diff > " + os.path.join(log_dir, filename_git_diff)], shell=True)

    return log_dir, checkpoint_path


def log_gradient_norm(tensor, variable_name):
    if variable_name not in _global_log:
        _global_log[variable_name] = []

    def log_gradient_norm_inner(gradient):
        gradient_norm = torch.norm(gradient, dim=-1)
        _global_log[variable_name].append(to_np(gradient_norm))

    tensor.register_hook(log_gradient_norm_inner)


def get_average(variable_name):
    if variable_name not in _global_log:
        return float('nan')
    elif _global_log[variable_name]:
        overall_tensor = np.concatenate(_global_log[variable_name])
        return np.mean(overall_tensor)
    else:
        return 0


def clear_data(variable_name):
    _global_log[variable_name] = []
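A minimal sketch of the gradient-norm logging flow (not part of the commit; tensor shapes are made up): register the hook before calling backward, then read and reset the accumulated norms.

import torch

from utils.utils_logging import log_gradient_norm, get_average, clear_data

x = torch.randn(4, 3, requires_grad=True)
log_gradient_norm(x, 'x_grad')   # hook records ||dL/dx|| along the last dim on backward

loss = (x ** 2).sum()
loss.backward()

print(get_average('x_grad'))     # mean over all norms recorded so far
clear_data('x_grad')             # empty buffer: get_average now returns 0
# get_average('never_registered') would return NaN instead.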
5 changes: 5 additions & 0 deletions network/utils/utils_profiling.py
@@ -0,0 +1,5 @@
# If the script runs under line_profiler's kernprof, a `profile` builtin
# already exists; otherwise define a no-op decorator so that code
# annotated with @profile still runs unchanged.
try:
    profile
except NameError:
    def profile(func):
        return func
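This shim targets line_profiler's kernprof workflow, matching the commented-out @profile in utils_logging.py above. A hedged usage sketch (hot_loop is hypothetical); the star-import is the important part: under kernprof -l the module body defines nothing, so the injected builtin profile is picked up, while a plain run falls back to the no-op exported above.

from utils.utils_profiling import *  # load before modules that use @profile

@profile
def hot_loop(n):
    total = 0
    for i in range(n):
        total += i * i
    return total

print(hot_loop(10000))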
