[go: nahoru, domu]

Skip to content

Commit

Permalink
Fixed evaluation part
Browse files (browse the repository at this point in the history)
  • Loading branch information
VincentStimper committed Oct 7, 2021
1 parent ae2be85 commit 1ae2128
Showing 1 changed file with 16 additions and 12 deletions.
experiments/train_uci.py: 28 changes (16 additions, 12 deletions)
Original file line number Diff line number Diff line change
@@ -265,9 +265,10 @@
 log_p_sum = 0
 num_nan = 0
 for x in iter(validate_loader):
-    x = x.to(device)
-    log_p = model.log_prob(x)
-    log_p_np = log_p.cpu().numpy()
+    with torch.no_grad():
+        x = x.to(device)
+        log_p = model.log_prob(x)
+        log_p_np = log_p.cpu().detach().numpy()
     isfinite = np.isfinite(log_p_np)
     num_nan += np.sum(~isfinite)
     log_p_sum += np.sum(log_p_np[isfinite])
@@ -286,9 +287,10 @@
 log_p_sum = 0
 num_nan = 0
 for x in iter(test_loader):
-    x = x.to(device)
-    log_p = model.log_prob(x)
-    log_p_np = log_p.cpu().numpy()
+    with torch.no_grad():
+        x = x.to(device)
+        log_p = model.log_prob(x)
+        log_p_np = log_p.cpu().detach().numpy()
     isfinite = np.isfinite(log_p_np)
     num_nan += np.sum(~isfinite)
     log_p_sum += np.sum(log_p_np[isfinite])
@@ -319,9 +321,10 @@
 log_p_sum = 0
 num_nan = 0
 for x in iter(validate_loader):
-    x = x.to(device)
-    log_p = model.log_prob(x)
-    log_p_np = log_p.cpu().numpy()
+    with torch.no_grad():
+        x = x.to(device)
+        log_p = model.log_prob(x)
+        log_p_np = log_p.cpu().detach().numpy()
     isfinite = np.isfinite(log_p_np)
     num_nan += np.sum(~isfinite)
     log_p_sum += np.sum(log_p_np[isfinite])
@@ -339,9 +342,10 @@
 log_p_sum = 0
 num_nan = 0
 for x in iter(test_loader):
-    x = x.to(device)
-    log_p = model.log_prob(x)
-    log_p_np = log_p.cpu().numpy()
+    with torch.no_grad():
+        x = x.to(device)
+        log_p = model.log_prob(x)
+        log_p_np = log_p.cpu().detach().numpy()
     isfinite = np.isfinite(log_p_np)
     num_nan += np.sum(~isfinite)
     log_p_sum += np.sum(log_p_np[isfinite])
Expand Down

0 comments on commit 1ae2128

Please sign in to comment.