[go: nahoru, domu]

Skip to content

Commit

Permalink
Merge pull request #579 from j-towns/norm-jvp
Browse files Browse the repository at this point in the history
Add linalg.norm jvp
  • Loading branch information
j-towns committed Jun 15, 2022
2 parents cb5f840 + a979e8c commit 2afc187
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 18 deletions.
65 changes: 56 additions & 9 deletions autograd/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy.linalg as npla
from .numpy_wrapper import wrap_namespace
from . import numpy_wrapper as anp
from autograd.extend import defvjp
from autograd.extend import defvjp, defjvp

wrap_namespace(npla.__dict__, globals())

Expand All @@ -18,8 +18,8 @@ def T(x): return anp.swapaxes(x, -1, -2)

# Batched matrix product: contracts the last axis of the first operand with
# the second-to-last axis of the second, broadcasting over any leading axes.
_dot = partial(anp.einsum, '...ij,...jk->...ik')

# batched diag
_diag = lambda a: anp.eye(a.shape[-1])*a
# batched diag
_diag = lambda a: anp.eye(a.shape[-1])*a

# batched diagonal, similar to matrix_diag in tensorflow
def _matrix_diag(a):
Expand Down Expand Up @@ -56,7 +56,7 @@ def grad_solve(argnum, ans, a, b):
return lambda g: solve(T(a), g)
defvjp(solve, partial(grad_solve, 0), partial(grad_solve, 1))

def grad_norm(ans, x, ord=None, axis=None):
def norm_vjp(ans, x, ord=None, axis=None):
def check_implemented():
matrix_norm = (x.ndim == 2 and axis is None) or isinstance(axis, tuple)

Expand Down Expand Up @@ -110,20 +110,67 @@ def vjp(g):
# see https://en.wikipedia.org/wiki/Norm_(mathematics)#p-norm
return expand(g / ans**(ord-1)) * x * anp.abs(x)**(ord-2)
return vjp
defvjp(norm, grad_norm)
defvjp(norm, norm_vjp)

def norm_jvp(g, ans, x, ord=None, axis=None):
    """Forward-mode derivative (JVP) of ``norm``.

    Given a tangent ``g`` for the input ``x`` and the primal output
    ``ans = norm(x, ord, axis)``, return the corresponding tangent of ``ans``.

    Supported cases:
      * matrix norms (``x.ndim == 2`` with ``axis=None``, or a tuple ``axis``)
        with ``ord`` in ``{None, 'fro', 'nuc'}``;
      * vector norms with ``ord`` None or a finite number greater than 1.

    Raises:
        NotImplementedError: for any other ``ord``. This includes an infinite
            ``ord``, where the p-norm formula used below does not apply.
    """
    def check_implemented():
        matrix_norm = (x.ndim == 2 and axis is None) or isinstance(axis, tuple)

        if matrix_norm:
            if not (ord is None or ord == 'fro' or ord == 'nuc'):
                raise NotImplementedError('Gradient of matrix norm not '
                                          'implemented for ord={}'.format(ord))
        elif not (ord is None or 1 < ord < float('inf')):
            # The p-norm derivative below is only valid for finite p > 1;
            # for ord=inf it would evaluate abs(x)**inf / ans**inf.
            raise NotImplementedError('Gradient of norm not '
                                      'implemented for ord={}'.format(ord))

    check_implemented()

    # Sum over the reduced axes so the tangent has the same shape as `ans`.
    if axis is None:
        contract = anp.sum
    else:
        contract = partial(anp.sum, axis=axis)

    if ord in (None, 2, 'fro'):
        # 2-norm / Frobenius norm: d||x|| = <x, dx> / ||x||
        return contract(g * x) / ans
    elif ord == 'nuc':
        # Nuclear norm: d||X||_* = <U V^T, dX>, where X = U diag(s) V^T is
        # the thin SVD taken over the two matrix axes.
        if axis is None:
            roll = unroll = lambda a: a
        else:
            # Normalise negative axes (numpy accepts e.g. axis=(0, -1)); the
            # position bookkeeping below assumes non-negative indices.
            row_axis, col_axis = (ax % x.ndim for ax in axis)
            if row_axis > col_axis:
                # Moving col_axis to the back shifts row_axis down by one.
                row_axis = row_axis - 1
            # Roll matrix axes to the back
            roll = lambda a: anp.rollaxis(anp.rollaxis(a, col_axis, a.ndim),
                                          row_axis, a.ndim-1)
            # Roll matrix axes to their original position
            unroll = lambda a: anp.rollaxis(anp.rollaxis(a, a.ndim-2, row_axis),
                                            a.ndim-1, col_axis)

        x_rolled = roll(x)
        u, s, vt = svd(x_rolled, full_matrices=False)
        uvt_rolled = _dot(u, vt)
        # Roll the matrix axes back to their correct positions
        uvt = unroll(uvt_rolled)
        return contract(g * uvt)
    else:
        # p-norm: d||x||_p = sum(|x|^(p-2) * x * dx) / ||x||_p^(p-1)
        # see https://en.wikipedia.org/wiki/Norm_(mathematics)#p-norm
        return contract(g * x * anp.abs(x)**(ord-2)) / ans**(ord-1)
defjvp(norm, norm_jvp)

def grad_eigh(ans, x, UPLO='L'):
"""Gradient for eigenvalues and vectors of a symmetric matrix."""
N = x.shape[-1]
w, v = ans # Eigenvalues, eigenvectors.
vc = anp.conj(v)

def vjp(g):
wg, vg = g # Gradient w.r.t. eigenvalues, eigenvectors.
w_repeated = anp.repeat(w[..., anp.newaxis], N, axis=-1)

# Eigenvalue part
vjp_temp = _dot(vc * wg[..., anp.newaxis, :], T(v))
vjp_temp = _dot(vc * wg[..., anp.newaxis, :], T(v))

# Add eigenvector part only if non-zero backward signal is present.
# This can avoid NaN results for degenerate cases if the function depends
Expand All @@ -142,7 +189,7 @@ def vjp(g):
tri = anp.tile(anp.tril(anp.ones(N), -1), reps)
elif UPLO == 'U':
tri = anp.tile(anp.triu(anp.ones(N), 1), reps)

return anp.real(vjp_temp)*anp.eye(vjp_temp.shape[-1]) + \
(vjp_temp + anp.conj(T(vjp_temp))) * tri

Expand Down Expand Up @@ -230,7 +277,7 @@ def vjp(g):
vtgv = _dot(T(v), gv)
t1 = (f * (utgu - anp.conj(T(utgu)))) * s[..., anp.newaxis, :]
t1 = t1 + i * gs[..., :, anp.newaxis]
t1 = t1 + s[..., :, anp.newaxis] * (f * (vtgv - anp.conj(T(vtgv))))
t1 = t1 + s[..., :, anp.newaxis] * (f * (vtgv - anp.conj(T(vtgv))))

if anp.iscomplexobj(u):
t1 = t1 + 1j*anp.imag(_diag(utgu)) / s[..., anp.newaxis, :]
Expand Down
20 changes: 11 additions & 9 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,45 +134,47 @@ def test_vector_2norm():
def fun(x): return np.linalg.norm(x)
D = 6
vec = npr.randn(D)
check_grads(fun)(vec)
check_grads(fun, modes=['fwd', 'rev'])(vec)

def test_frobenius_norm():
def fun(x): return np.linalg.norm(x)
D = 6
mat = npr.randn(D, D-1)
check_grads(fun)(mat)
check_grads(fun, modes=['fwd', 'rev'])(mat)

def test_frobenius_norm_axis():
def fun(x): return np.linalg.norm(x, axis=(0, 1))
D = 6
mat = npr.randn(D, D-1, D-2)
check_grads(fun)(mat)
check_grads(fun, modes=['fwd', 'rev'])(mat)

@pytest.mark.parametrize("ord", range(2, 5))
@pytest.mark.parametrize("size", [6])
def test_vector_norm_ord(size, ord):
def fun(x): return np.linalg.norm(x, ord=ord)
vec = npr.randn(size)
check_grads(fun)(vec)
check_grads(fun, modes=['fwd', 'rev'])(vec)

@pytest.mark.parametrize("axis", range(3))
@pytest.mark.parametrize("shape", [(6, 5, 4)])
def test_norm_axis(shape, axis):
def fun(x): return np.linalg.norm(x, axis=axis)
arr = npr.randn(*shape)
check_grads(fun)(arr)
check_grads(fun, modes=['fwd', 'rev'])(arr)

def test_norm_nuclear():
def fun(x): return np.linalg.norm(x, ord='nuc')
D = 6
mat = npr.randn(D, D-1)
check_grads(fun)(mat)
# Order 1 because the jvp of the svd is not implemented
check_grads(fun, modes=['fwd', 'rev'], order=1)(mat)

def test_norm_nuclear_axis():
def fun(x): return np.linalg.norm(x, ord='nuc', axis=(0, 1))
D = 6
mat = npr.randn(D, D-1, D-2)
check_grads(fun)(mat)
# Order 1 because the jvp of the svd is not implemented
check_grads(fun, modes=['fwd', 'rev'], order=1)(mat)

def test_eigvalh_lower():
def fun(x):
Expand Down Expand Up @@ -210,8 +212,8 @@ def fun(x):
check_grads(fun)(hmat)

# For complex-valued matrices, the eigenvectors could have arbitrary phases (gauge)
# which makes it impossible to compare to numerical derivatives. So we take the
# absolute value to get rid of that phase.
# which makes it impossible to compare to numerical derivatives. So we take the
# absolute value to get rid of that phase.
def test_eigvalh_lower_complex():
def fun(x):
w, v = np.linalg.eigh(x)
Expand Down

0 comments on commit 2afc187

Please sign in to comment.