[go: nahoru, domu]

Skip to content

Commit

Permalink
Merge pull request #538 from buwantaiji/fix-svd
Browse files Browse the repository at this point in the history
Fix complex-valued SVD backpropagation.
  • Loading branch information
j-towns committed Sep 23, 2019
2 parents 6e5ec96 + da3989b commit 22d2c8c
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 50 deletions.
71 changes: 25 additions & 46 deletions autograd/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ def T(x): return anp.swapaxes(x, -1, -2)

_dot = partial(anp.einsum, '...ij,...jk->...ik')

# batched diag
_diag = lambda a: anp.eye(a.shape[-1])*a

# add two dimensions to the end of x
def add2d(x):
    """Return x reshaped with two extra length-1 axes appended to its shape."""
    expanded_shape = anp.shape(x) + (1, 1)
    return anp.reshape(x, expanded_shape)

Expand Down Expand Up @@ -155,6 +158,8 @@ def vjp(g):
return vjp
defvjp(cholesky, grad_cholesky)

# https://j-towns.github.io/papers/svd-derivative.pdf
# https://arxiv.org/abs/1909.02659
def grad_svd(usv_, a, full_matrices=True, compute_uv=True):
def vjp(g):
usv = usv_
Expand All @@ -165,9 +170,9 @@ def vjp(g):
# Need U and V so do the whole svd anyway...
usv = svd(a, full_matrices=False)
u = usv[0]
v = T(usv[2])
v = anp.conj(T(usv[2]))

return _dot(u * g[..., anp.newaxis, :], T(v))
return _dot(anp.conj(u) * g[..., anp.newaxis, :], T(v))

elif full_matrices:
raise NotImplementedError(
Expand All @@ -176,7 +181,7 @@ def vjp(g):
else:
u = usv[0]
s = usv[1]
v = T(usv[2])
v = anp.conj(T(usv[2]))

m, n = a.shape[-2:]

Expand All @@ -186,61 +191,35 @@ def vjp(g):

f = 1 / (s[..., anp.newaxis, :]**2 - s[..., :, anp.newaxis]**2 + i)

if m < n:
gu = g[0]
gs = g[1]
gv = T(g[2])

utgu = _dot(T(u), gu)
vtgv = _dot(T(v), gv)
gu = g[0]
gs = g[1]
gv = anp.conj(T(g[2]))

i_minus_vvt = (anp.reshape(anp.eye(n), anp.concatenate((anp.ones(a.ndim - 2, dtype=int), (n, n)))) -
_dot(v, T(v)))
utgu = _dot(T(u), gu)
vtgv = _dot(T(v), gv)
t1 = (f * (utgu - anp.conj(T(utgu)))) * s[..., anp.newaxis, :]
t1 = t1 + i * gs[..., :, anp.newaxis]
t1 = t1 + s[..., :, anp.newaxis] * (f * (vtgv - anp.conj(T(vtgv))))

t1 = (f * (utgu - T(utgu))) * s[..., anp.newaxis, :]
t1 = t1 + i * gs[..., :, anp.newaxis]
t1 = t1 + s[..., :, anp.newaxis] * (f * (vtgv - T(vtgv)))
if anp.iscomplexobj(u):
t1 = t1 + 1j*anp.imag(_diag(utgu)) / s[..., anp.newaxis, :]

t1 = _dot(_dot(u, t1), T(v))
t1 = _dot(_dot(anp.conj(u), t1), T(v))

t1 = t1 + _dot(_dot(u / s[..., anp.newaxis, :], T(gv)), i_minus_vvt)
if m < n:
i_minus_vvt = (anp.reshape(anp.eye(n), anp.concatenate((anp.ones(a.ndim - 2, dtype=int), (n, n)))) -
_dot(v, anp.conj(T(v))))
t1 = t1 + anp.conj(_dot(_dot(u / s[..., anp.newaxis, :], T(gv)), i_minus_vvt))

return t1

elif m == n:
gu = g[0]
gs = g[1]
gv = T(g[2])

utgu = _dot(T(u), gu)
vtgv = _dot(T(v), gv)

t1 = (f * (utgu - T(utgu))) * s[..., anp.newaxis, :]
t1 = t1 + i * gs[..., :, anp.newaxis]
t1 = t1 + s[..., :, anp.newaxis] * (f * (vtgv - T(vtgv)))

t1 = _dot(_dot(u, t1), T(v))

return t1

elif m > n:
gu = g[0]
gs = g[1]
gv = T(g[2])

utgu = _dot(T(u), gu)
vtgv = _dot(T(v), gv)

i_minus_uut = (anp.reshape(anp.eye(m), anp.concatenate((anp.ones(a.ndim - 2, dtype=int), (m, m)))) -
_dot(u, T(u)))

t1 = (f * (utgu - T(utgu))) * s[..., anp.newaxis, :]
t1 = t1 + i * gs[..., :, anp.newaxis]
t1 = t1 + s[..., :, anp.newaxis] * (f * (vtgv - T(vtgv)))

t1 = _dot(_dot(u, t1), T(v))

t1 = t1 + _dot(i_minus_uut, _dot(gu, T(v) / s[..., :, anp.newaxis]))
_dot(u, anp.conj(T(u))))
t1 = t1 + T(_dot(_dot(v/s[..., anp.newaxis, :], T(gu)), i_minus_uut) )

return t1
return vjp
Expand Down
93 changes: 89 additions & 4 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,16 @@ def fun(x):
mat = npr.randn(m, n)
check_grads(fun)(mat)

def test_svd_wide_2d_complex():
    """Check gradients of the reduced SVD on a complex 3x5 (m < n) matrix."""
    def fun(x):
        u, s, v = np.linalg.svd(x, full_matrices=False)
        # Compare |u| and |v| rather than u and v themselves -- presumably
        # because complex singular vectors are only unique up to a phase
        # factor (TODO confirm).
        return (np.abs(u), s, np.abs(v))

    m = 3
    n = 5
    mat = npr.randn(m, n) + 1j * npr.randn(m, n)
    check_grads(fun)(mat)

def test_svd_wide_3d():
def fun(x):
u, s, v = np.linalg.svd(x, full_matrices=False)
Expand All @@ -263,10 +273,21 @@ def fun(x):
k = 4
m = 3
n = 5

mat = npr.randn(k, m, n)
check_grads(fun)(mat)

def test_svd_wide_3d_complex():
    """Check gradients of the reduced SVD on a complex array of shape (4, 3, 5)."""
    def fun(x):
        u, s, v = np.linalg.svd(x, full_matrices=False)
        # Compare |u| and |v| rather than u and v themselves -- presumably
        # because complex singular vectors are only unique up to a phase
        # factor (TODO confirm).
        return (np.abs(u), s, np.abs(v))

    k = 4
    m = 3
    n = 5
    mat = npr.randn(k, m, n) + 1j * npr.randn(k, m, n)
    check_grads(fun)(mat)

def test_svd_square_2d():
def fun(x):
u, s, v = np.linalg.svd(x, full_matrices=False)
Expand All @@ -277,6 +298,16 @@ def fun(x):
mat = npr.randn(m, n)
check_grads(fun)(mat)

def test_svd_square_2d_complex():
    """Check gradients of the reduced SVD on a complex 4x4 (m == n) matrix."""
    def fun(x):
        u, s, v = np.linalg.svd(x, full_matrices=False)
        # Compare |u| and |v| rather than u and v themselves -- presumably
        # because complex singular vectors are only unique up to a phase
        # factor (TODO confirm).
        return (np.abs(u), s, np.abs(v))

    m = 4
    n = 4
    mat = npr.randn(m, n) + 1j * npr.randn(m, n)
    check_grads(fun)(mat)

def test_svd_square_3d():
def fun(x):
u, s, v = np.linalg.svd(x, full_matrices=False)
Expand All @@ -286,10 +317,21 @@ def fun(x):
k = 3
m = 4
n = 4

mat = npr.randn(k, m, n)
check_grads(fun)(mat)

def test_svd_square_3d_complex():
    """Check gradients of the reduced SVD on a complex array of shape (3, 4, 4)."""
    def fun(x):
        u, s, v = np.linalg.svd(x, full_matrices=False)
        # Compare |u| and |v| rather than u and v themselves -- presumably
        # because complex singular vectors are only unique up to a phase
        # factor (TODO confirm).
        return (np.abs(u), s, np.abs(v))

    k = 3
    m = 4
    n = 4
    mat = npr.randn(k, m, n) + 1j * npr.randn(k, m, n)
    check_grads(fun)(mat)

def test_svd_tall_2d():
def fun(x):
u, s, v = np.linalg.svd(x, full_matrices=False)
Expand All @@ -300,6 +342,16 @@ def fun(x):
mat = npr.randn(m, n)
check_grads(fun)(mat)

def test_svd_tall_2d_complex():
    """Check gradients of the reduced SVD on a complex 5x3 (m > n) matrix."""
    def fun(x):
        u, s, v = np.linalg.svd(x, full_matrices=False)
        # Compare |u| and |v| rather than u and v themselves -- presumably
        # because complex singular vectors are only unique up to a phase
        # factor (TODO confirm).
        return (np.abs(u), s, np.abs(v))

    m = 5
    n = 3
    mat = npr.randn(m, n) + 1j * npr.randn(m, n)
    check_grads(fun)(mat)

def test_svd_tall_3d():
def fun(x):
u, s, v = np.linalg.svd(x, full_matrices=False)
Expand All @@ -309,10 +361,21 @@ def fun(x):
k = 4
m = 5
n = 3

mat = npr.randn(k, m, n)
check_grads(fun)(mat)

def test_svd_tall_3d_complex():
    """Check gradients of the reduced SVD on a complex array of shape (4, 5, 3)."""
    def fun(x):
        u, s, v = np.linalg.svd(x, full_matrices=False)
        # Compare |u| and |v| rather than u and v themselves -- presumably
        # because complex singular vectors are only unique up to a phase
        # factor (TODO confirm).
        return (np.abs(u), s, np.abs(v))

    k = 4
    m = 5
    n = 3
    mat = npr.randn(k, m, n) + 1j * npr.randn(k, m, n)
    check_grads(fun)(mat)

def test_svd_only_s_2d():
def fun(x):
s = np.linalg.svd(x, full_matrices=False, compute_uv=False)
Expand All @@ -324,6 +387,17 @@ def fun(x):
mat = npr.randn(m, n)
check_grads(fun)(mat)

def test_svd_only_s_2d_complex():
    """Check gradients of the singular values alone for a complex 5x3 matrix."""
    def fun(x):
        # compute_uv=False makes svd return only the singular values.
        return np.linalg.svd(x, full_matrices=False, compute_uv=False)

    m = 5
    n = 3
    mat = npr.randn(m, n) + 1j * npr.randn(m, n)
    check_grads(fun)(mat)

def test_svd_only_s_3d():
def fun(x):
s = np.linalg.svd(x, full_matrices=False, compute_uv=False)
Expand All @@ -333,6 +407,17 @@ def fun(x):
k = 4
m = 5
n = 3

mat = npr.randn(k, m, n)
check_grads(fun)(mat)

def test_svd_only_s_3d_complex():
    """Check gradients of the singular values alone for a complex (4, 5, 3) array."""
    def fun(x):
        # compute_uv=False makes svd return only the singular values.
        return np.linalg.svd(x, full_matrices=False, compute_uv=False)

    k = 4
    m = 5
    n = 3
    mat = npr.randn(k, m, n) + 1j * npr.randn(k, m, n)
    check_grads(fun)(mat)

0 comments on commit 22d2c8c

Please sign in to comment.