-
Notifications
You must be signed in to change notification settings - Fork 429
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix typos and add missing utils directory
- Loading branch information
Minkyung Baek
committed
Jul 2, 2021
1 parent
a98a404
commit 5510443
Showing
7 changed files
with
194 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,3 +6,4 @@ dependencies: | |
- tensorflow-gpu=1.14 | ||
- pandas | ||
- scikit-learn=0.24 | ||
- parallel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import warnings | ||
|
||
import dgl | ||
import torch | ||
|
||
|
||
def to_np(x): | ||
return x.cpu().detach().numpy() | ||
|
||
|
||
class PickleGraph: | ||
"""Lightweight graph object for easy pickling. Does not support batched graphs.""" | ||
|
||
def __init__(self, G=None, desired_keys=None): | ||
self.ndata = dict() | ||
self.edata = dict() | ||
|
||
if G is None: | ||
self.src = [] | ||
self.dst = [] | ||
else: | ||
if G.batch_size > 1: | ||
warnings.warn("Copying a batched graph to a PickleGraph is not supported. " | ||
"All node and edge data will be copied, but batching information will be lost.") | ||
|
||
self.src, self.dst = (to_np(idx) for idx in G.all_edges()) | ||
|
||
for k in G.ndata: | ||
if desired_keys is None or k in desired_keys: | ||
self.ndata[k] = to_np(G.ndata[k]) | ||
|
||
for k in G.edata: | ||
if desired_keys is None or k in desired_keys: | ||
self.edata[k] = to_np(G.edata[k]) | ||
|
||
def all_edges(self): | ||
return self.src, self.dst | ||
|
||
|
||
def copy_dgl_graph(G): | ||
if G.batch_size == 1: | ||
src, dst = G.all_edges() | ||
G2 = dgl.DGLGraph((src, dst)) | ||
for edge_key in list(G.edata.keys()): | ||
G2.edata[edge_key] = torch.clone(G.edata[edge_key]) | ||
for node_key in list(G.ndata.keys()): | ||
G2.ndata[node_key] = torch.clone(G.ndata[node_key]) | ||
return G2 | ||
else: | ||
list_of_graphs = dgl.unbatch(G) | ||
list_of_copies = [] | ||
|
||
for batch_G in list_of_graphs: | ||
list_of_copies.append(copy_dgl_graph(batch_G)) | ||
|
||
return dgl.batch(list_of_copies) | ||
|
||
|
||
def update_relative_positions(G, *, relative_position_key='d', absolute_position_key='x'): | ||
"""For each directed edge in the graph, calculate the relative position of the destination node with respect | ||
to the source node. Write the relative positions to the graph as edge data.""" | ||
src, dst = G.all_edges() | ||
absolute_positions = G.ndata[absolute_position_key] | ||
G.edata[relative_position_key] = absolute_positions[dst] - absolute_positions[src] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
import os | ||
import sys | ||
import time | ||
import datetime | ||
import subprocess | ||
import numpy as np | ||
import torch | ||
|
||
from utils.utils_data import to_np | ||
|
||
|
||
_global_log = {} | ||
|
||
|
||
def try_mkdir(path): | ||
if not os.path.exists(path): | ||
os.makedirs(path) | ||
|
||
|
||
# @profile | ||
def make_logdir(checkpoint_dir, run_name=None): | ||
if run_name is None: | ||
now = datetime.datetime.now().strftime("%Y_%m_%d_%H.%M.%S") | ||
else: | ||
assert type(run_name) == str | ||
now = run_name | ||
|
||
log_dir = os.path.join(checkpoint_dir, now) | ||
try_mkdir(log_dir) | ||
return log_dir | ||
|
||
|
||
def count_parameters(model): | ||
""" | ||
count number of trainable parameters in module | ||
:param model: nn.Module instance | ||
:return: integer | ||
""" | ||
model_parameters = filter(lambda p: p.requires_grad, model.parameters()) | ||
n_params = sum([np.prod(p.size()) for p in model_parameters]) | ||
return n_params | ||
|
||
|
||
def write_info_file(model, FLAGS, UNPARSED_ARGV, wandb_log_dir=None): | ||
time_str = time.strftime("%m%d_%H%M%S") | ||
filename_log = "info_" + time_str + ".txt" | ||
filename_git_diff = "git_diff_" + time_str + ".txt" | ||
|
||
checkpoint_name = 'model' | ||
|
||
if wandb_log_dir: | ||
log_dir = wandb_log_dir | ||
os.mkdir(os.path.join(log_dir, 'checkpoints')) | ||
checkpoint_path = os.path.join(log_dir, 'checkpoints', checkpoint_name) | ||
elif FLAGS.restore: | ||
# set restore path | ||
assert FLAGS.run_name is not None | ||
log_dir = os.path.join(FLAGS.checkpoint_dir, FLAGS.run_name) | ||
checkpoint_path = os.path.join(log_dir, 'checkpoints', checkpoint_name) | ||
else: | ||
# makes logdir with time stamp | ||
log_dir = make_logdir(FLAGS.checkpoint_dir, FLAGS.run_name) | ||
os.mkdir(os.path.join(log_dir, 'checkpoints')) | ||
os.mkdir(os.path.join(log_dir, 'point_clouds')) | ||
# os.mkdir(os.path.join(log_dir, 'train_log')) | ||
# os.mkdir(os.path.join(log_dir, 'test_log')) | ||
checkpoint_path = os.path.join(log_dir, 'checkpoints', checkpoint_name) | ||
|
||
# writing arguments and git hash to info file | ||
file = open(os.path.join(log_dir, filename_log), "w") | ||
label = subprocess.check_output(["git", "describe", "--always"]).strip() | ||
file.write('latest git commit on this branch: ' + str(label) + '\n') | ||
file.write('\nFLAGS: \n') | ||
for key in sorted(vars(FLAGS)): | ||
file.write(key + ': ' + str(vars(FLAGS)[key]) + '\n') | ||
|
||
# count number of parameters | ||
if hasattr(model, 'parameters'): | ||
file.write('\nNumber of Model Parameters: ' + str(count_parameters(model)) + '\n') | ||
if hasattr(model, 'enc'): | ||
file.write('\nNumber of Encoder Parameters: ' + str( | ||
count_parameters(model.enc)) + '\n') | ||
if hasattr(model, 'dec'): | ||
file.write('\nNumber of Decoder Parameters: ' + str( | ||
count_parameters(model.dec)) + '\n') | ||
|
||
file.write('\nUNPARSED_ARGV:\n' + str(UNPARSED_ARGV)) | ||
file.write('\n\nBASH COMMAND: \n') | ||
bash_command = 'python' | ||
for argument in sys.argv: | ||
bash_command += (' ' + argument) | ||
file.write(bash_command) | ||
file.close() | ||
|
||
# write 'git diff' output into extra file | ||
subprocess.call(["git diff > " + os.path.join(log_dir, filename_git_diff)], shell=True) | ||
|
||
return log_dir, checkpoint_path | ||
|
||
|
||
def log_gradient_norm(tensor, variable_name): | ||
if variable_name not in _global_log: | ||
_global_log[variable_name] = [] | ||
|
||
def log_gradient_norm_inner(gradient): | ||
gradient_norm = torch.norm(gradient, dim=-1) | ||
_global_log[variable_name].append(to_np(gradient_norm)) | ||
|
||
tensor.register_hook(log_gradient_norm_inner) | ||
|
||
|
||
def get_average(variable_name): | ||
if variable_name not in _global_log: | ||
return float('nan') | ||
elif _global_log[variable_name]: | ||
overall_tensor = np.concatenate(_global_log[variable_name]) | ||
return np.mean(overall_tensor) | ||
else: | ||
return 0 | ||
|
||
|
||
def clear_data(variable_name): | ||
_global_log[variable_name] = [] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
try: | ||
profile | ||
except NameError: | ||
def profile(func): | ||
return func |