Merge pull request #543 from refraction-ray/ad_eig
Add np.linalg.eig vjp
j-towns committed Nov 18, 2019
2 parents b37252a + be6efaa commit c6f630a
Showing 2 changed files with 63 additions and 16 deletions.
31 changes: 31 additions & 0 deletions autograd/numpy/linalg.py
@@ -21,6 +21,14 @@ def T(x): return anp.swapaxes(x, -1, -2)
# batched diag: keep the diagonal of the trailing two dims, zero the rest
_diag = lambda a: anp.eye(a.shape[-1])*a

# batched diagonal, similar to matrix_diag in tensorflow
def _matrix_diag(a):
    reps = anp.array(a.shape)
    reps[:-1] = 1
    reps[-1] = a.shape[-1]
    newshape = list(a.shape) + [a.shape[-1]]
    return _diag(anp.tile(a, reps).reshape(newshape))
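
# Illustrative example (not part of the original diff): _matrix_diag maps a
# batch of vectors to a batch of diagonal matrices, e.g.
#   _matrix_diag(anp.array([[1., 2.], [3., 4.]]))
#   # -> array([[[1., 0.], [0., 2.]],
#   #           [[3., 0.], [0., 4.]]])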

# add two dimensions to the end of x
def add2d(x): return anp.reshape(x, anp.shape(x) + (1, 1))

@@ -141,6 +149,29 @@ def vjp(g):
    return vjp
defvjp(eigh, grad_eigh)

# https://arxiv.org/pdf/1701.00392.pdf Eq. (4.77)
# Note: the formula in Sec. 3.1 of https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf is incomplete.
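# In cotangent form, the update implemented below is
#   r = inv(u.T) @ (diag(ge) + f * (u.T @ gu)
#                   - f * ((u.T @ conj(u)) @ diag_part(real(u.T @ gu)))) @ u.T
# where f[i, j] = 1/(e[j] - e[i]) for i != j and 0 on the diagonal, diag(ge)
# places the eigenvalue cotangents on a diagonal matrix, and diag_part keeps
# only the diagonal of its argument.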
def grad_eig(ans, x):
    """Gradient of eig for a general square (possibly complex-valued) matrix."""
    e, u = ans  # eigenvalues as a 1d array, eigenvectors in columns
    n = e.shape[-1]
    def vjp(g):
        ge, gu = g
        ge = _matrix_diag(ge)
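        # the 1e-20 shift below keeps the division finite for (near-)degenerate
        # eigenvalues; the diagonal of f is zeroed out on the following line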
        f = 1/(e[..., anp.newaxis, :] - e[..., :, anp.newaxis] + 1.e-20)
        f -= _diag(f)
        ut = anp.swapaxes(u, -1, -2)
        r1 = f * _dot(ut, gu)
        r2 = -f * (_dot(_dot(ut, anp.conj(u)), anp.real(_dot(ut, gu)) * anp.eye(n)))
        r = _dot(_dot(inv(ut), ge + r1 + r2), ut)
        if not anp.iscomplexobj(x):
            r = anp.real(r)
            # for real input the mathematical derivative is in general complex
            # (an imaginary perturbation is allowed); restricting perturbations
            # to real matrices makes the real part the correct derivative
        return r
    return vjp
defvjp(eig, grad_eig)
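
# Illustrative usage (a sketch, not part of the original diff): with the vjp
# registered, autograd can differentiate scalar functions through
# np.linalg.eig, e.g.
#
#   import autograd.numpy as np
#   import autograd.numpy.random as npr
#   from autograd import grad
#
#   def fun(x):
#       e, u = np.linalg.eig(x)
#       return np.sum(np.abs(e))       # real scalar function of the spectrum
#
#   g = grad(fun)(npr.randn(5, 5))     # a real (5, 5) gradient array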

def grad_cholesky(L, A):
    # Based on Iain Murray's note http://arxiv.org/abs/1602.07527
    # scipy's dtrtrs wrapper, solve_triangular, doesn't broadcast along leading
48 changes: 32 additions & 16 deletions tests/test_linalg.py
@@ -228,6 +228,32 @@ def fun(x):
    mat = npr.randn(D, D) + 1j*npr.randn(D, D)
    check_grads(fun)(mat)

# Note: the eigenvalues and eigenvectors of a real matrix can still be complex.
def test_eig_real():
    def fun(x):
        w, v = np.linalg.eig(x)
        return tuple((np.abs(w), np.abs(v)))
    D = 8
    mat = npr.randn(D, D)
    check_grads(fun)(mat)

def test_eig_complex():
    def fun(x):
        w, v = np.linalg.eig(x)
        return tuple((w, np.abs(v)))
    D = 8
    mat = npr.randn(D, D) + 1.j * npr.randn(D, D)
    check_grads(fun)(mat)

def test_eig_batched():
    def fun(x):
        w, v = np.linalg.eig(x)
        return tuple((w, np.abs(v)))
    D = 8
    b = 5
    mat = npr.randn(b, D, D) + 1.j * npr.randn(b, D, D)
    check_grads(fun)(mat)

def test_cholesky():
    fun = lambda A: np.linalg.cholesky(A)
    check_symmetric_matrix_grads(fun)(rand_psd(6))
@@ -248,7 +274,7 @@ def test_svd_wide_2d():
    def fun(x):
        u, s, v = np.linalg.svd(x, full_matrices=False)
        return tuple((u, s, v))
    return grad(fun)(x)

    m = 3
    n = 5
    mat = npr.randn(m, n)
@@ -258,7 +284,7 @@ def test_svd_wide_2d_complex():
    def fun(x):
        u, s, v = np.linalg.svd(x, full_matrices=False)
        return tuple((np.abs(u), s, np.abs(v)))
    return grad(fun)(x)

    m = 3
    n = 5
    mat = npr.randn(m, n) + 1j * npr.randn(m, n)
@@ -268,7 +294,6 @@ def test_svd_wide_3d():
    def fun(x):
        u, s, v = np.linalg.svd(x, full_matrices=False)
        return tuple((u, s, v))
    return grad(fun)(x)

    k = 4
    m = 3
@@ -280,7 +305,6 @@ def test_svd_wide_3d_complex():
    def fun(x):
        u, s, v = np.linalg.svd(x, full_matrices=False)
        return tuple((np.abs(u), s, np.abs(v)))
    return grad(fun)(x)

    k = 4
    m = 3
@@ -292,7 +316,7 @@ def test_svd_square_2d():
    def fun(x):
        u, s, v = np.linalg.svd(x, full_matrices=False)
        return tuple((u, s, v))
    return grad(fun)(x)

    m = 4
    n = 4
    mat = npr.randn(m, n)
@@ -302,7 +326,7 @@ def test_svd_square_2d_complex():
    def fun(x):
        u, s, v = np.linalg.svd(x, full_matrices=False)
        return tuple((np.abs(u), s, np.abs(v)))
    return grad(fun)(x)

    m = 4
    n = 4
    mat = npr.randn(m, n) + 1j * npr.randn(m, n)
@@ -312,7 +336,6 @@ def test_svd_square_3d():
    def fun(x):
        u, s, v = np.linalg.svd(x, full_matrices=False)
        return tuple((u, s, v))
    return grad(fun)(x)

    k = 3
    m = 4
@@ -324,7 +347,6 @@ def test_svd_square_3d_complex():
    def fun(x):
        u, s, v = np.linalg.svd(x, full_matrices=False)
        return tuple((np.abs(u), s, np.abs(v)))
    return grad(fun)(x)

    k = 3
    m = 4
@@ -336,7 +358,7 @@ def test_svd_tall_2d():
    def fun(x):
        u, s, v = np.linalg.svd(x, full_matrices=False)
        return tuple((u, s, v))
    return grad(fun)(x)

    m = 5
    n = 3
    mat = npr.randn(m, n)
@@ -346,7 +368,7 @@ def test_svd_tall_2d_complex():
    def fun(x):
        u, s, v = np.linalg.svd(x, full_matrices=False)
        return tuple((np.abs(u), s, np.abs(v)))
    return grad(fun)(x)

    m = 5
    n = 3
    mat = npr.randn(m, n) + 1j * npr.randn(m, n)
@@ -356,7 +378,6 @@ def test_svd_tall_3d():
    def fun(x):
        u, s, v = np.linalg.svd(x, full_matrices=False)
        return tuple((u, s, v))
    return grad(fun)(x)

    k = 4
    m = 5
@@ -368,7 +389,6 @@ def test_svd_tall_3d_complex():
    def fun(x):
        u, s, v = np.linalg.svd(x, full_matrices=False)
        return tuple((np.abs(u), s, np.abs(v)))
    return grad(fun)(x)

    k = 4
    m = 5
@@ -380,7 +400,6 @@ def test_svd_only_s_2d():
    def fun(x):
        s = np.linalg.svd(x, full_matrices=False, compute_uv=False)
        return s
    return grad(fun)(x)

    m = 5
    n = 3
@@ -391,7 +410,6 @@ def test_svd_only_s_2d_complex():
    def fun(x):
        s = np.linalg.svd(x, full_matrices=False, compute_uv=False)
        return s
    return grad(fun)(x)

    m = 5
    n = 3
@@ -402,7 +420,6 @@ def test_svd_only_s_3d():
    def fun(x):
        s = np.linalg.svd(x, full_matrices=False, compute_uv=False)
        return s
    return grad(fun)(x)

    k = 4
    m = 5
@@ -414,7 +431,6 @@ def test_svd_only_s_3d_complex():
    def fun(x):
        s = np.linalg.svd(x, full_matrices=False, compute_uv=False)
        return s
    return grad(fun)(x)

    k = 4
    m = 5
