Merge pull request #1 from deepmind:tomhennigan-patch-1
PiperOrigin-RevId: 368405770
Copybara-Service committed Apr 14, 2021
2 parents 9aed136 + aa4eeee commit ffbe9c5
Showing 1 changed file with 3 additions and 3 deletions.

README.md
@@ -59,7 +59,7 @@ my_policy = jmp.Policy(compute_dtype=half,

The policy object can be used to cast pytrees:

-```python{highlight="lines:2,5,8"}
+```python
def layer(params, x):
params, x = my_policy.cast_to_compute((params, x))
w, b = params
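
For reference, this hunk touches jmp's policy-casting example, which the diff only shows in part. A minimal runnable sketch of the same pattern, assuming a simple linear layer (the `half`/`full` aliases and the input shapes are illustrative, not part of the diff):

```python
import jax.numpy as jnp
import jmp

half = jnp.float16  # jnp.bfloat16 is the usual choice on TPU
full = jnp.float32

# Keep parameters in full precision; run compute and outputs in half.
my_policy = jmp.Policy(compute_dtype=half,
                       param_dtype=full,
                       output_dtype=half)

def layer(params, x):
  # Cast the whole (params, x) pytree to the compute dtype in one call.
  params, x = my_policy.cast_to_compute((params, x))
  w, b = params
  return my_policy.cast_to_output(x @ w + b)

w = jnp.ones([3, 2], full)
b = jnp.zeros([2], full)
x = jnp.ones([4, 3], full)
y = layer((w, b), x)  # y.dtype == jnp.float16
```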
@@ -96,7 +96,7 @@ particularly important when training with `float16` and less important for
The easiest way to shift gradients is with loss scaling, which scales your loss
and gradients by `S` and `1/S` respectively.

-```python{highlight="content:\bloss_scale\b"}
+```python
def my_loss_fn(params, loss_scale: jmp.LossScale, ...):
loss = ...
# You should apply regularization etc before scaling.
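
As a sketch of the static loss-scaling pattern this hunk documents (the concrete loss, data, and gradient call below are assumptions for illustration, not the README's full example):

```python
import jax
import jax.numpy as jnp
import jmp

def my_loss_fn(params, loss_scale: jmp.LossScale, x):
  loss = jnp.mean((x @ params) ** 2)  # illustrative loss
  # Apply regularization etc. before scaling, then scale the loss by S.
  return loss_scale.scale(loss)

loss_scale = jmp.StaticLossScale(jnp.float32(2 ** 15))
params = jnp.ones([3])
x = jnp.ones([4, 3], jnp.float16)
grads = jax.grad(my_loss_fn)(params, loss_scale, x)
# Unscale by 1/S before gradient clipping and the optimizer update.
grads = loss_scale.unscale(grads)
```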
@@ -127,7 +127,7 @@ during training to find the largest value for `S` that produces finite
gradients. This is more convenient and robust compared with picking a static
loss scale, but has a small performance impact (between 1 and 5%).

-```python{highlight="content:\bloss_scale\b"}
+```python
def my_loss_fn(params, loss_scale: jmp.LossScale, ...):
loss = ...
# You should apply regularization etc before scaling.
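
And a sketch of the dynamic variant, with hypothetical training-step plumbing around jmp's `DynamicLossScale` API (`jmp.all_finite`, `loss_scale.adjust`, and `jmp.select_tree` are the relevant calls; the loss and update rule are illustrative):

```python
import jax
import jax.numpy as jnp
import jmp

def my_loss_fn(params, loss_scale: jmp.LossScale, x):
  loss = jnp.mean((x @ params) ** 2)  # illustrative loss
  return loss_scale.scale(loss)

def train_step(params, loss_scale: jmp.DynamicLossScale, x, lr=0.1):
  grads = jax.grad(my_loss_fn)(params, loss_scale, x)
  grads = loss_scale.unscale(grads)
  # S is periodically increased while gradients stay finite and is
  # decreased whenever they overflow to inf/nan.
  grads_finite = jmp.all_finite(grads)
  loss_scale = loss_scale.adjust(grads_finite)
  # Skip the parameter update entirely on non-finite gradients.
  new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
  params = jmp.select_tree(grads_finite, new_params, params)
  return params, loss_scale

params = jnp.ones([3])
x = jnp.ones([4, 3], jnp.float16)
loss_scale = jmp.DynamicLossScale(jnp.float32(2 ** 15))
params, loss_scale = train_step(params, loss_scale, x)
```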
