Branch 171672655 #13606 (Closed)

Wants to merge 66 commits. Changes shown from 1 commit.

Commits (66):
bbfef93
Convert shape to TensorShape when creating _VariableFromResource
akshayka Oct 6, 2017
eb1a0a5
(1) Adds broadcasting to scaled_softplus
tensorflower-gardener Oct 6, 2017
e744cca
Changes Relu6Grad to depend on relu6's output rather than its input, …
tensorflower-gardener Oct 6, 2017
25e6d23
Adds helpers for bucketing strategies for TF monitoring samplers.
vinuraja Oct 6, 2017
c5f715f
Update ops-related pbtxt files.
tensorflower-gardener Oct 6, 2017
710efee
Bump min graph consumer version when adding functions to it
iganichev Oct 6, 2017
a713e49
Go: Update generated wrapper functions for TensorFlow ops.
tensorflower-gardener Oct 6, 2017
d9a969c
Disable some tests on tsan.
gunan Oct 6, 2017
6fc7de9
Define object-oriented metrics classes that are Eager-safe.
tensorflower-gardener Oct 6, 2017
c26542c
[XLA] Don't clone and throw away instructions without calling DetachF…
Oct 7, 2017
fb3c68d
Disable keras:models_test in tsan mode.
gunan Oct 7, 2017
646db3e
eager: Compute num_gpus() correctly.
asimshankar Oct 7, 2017
96d276f
Improvements and fixes in VirtualPlacer:
yacoder Oct 7, 2017
010dd39
Disable predict_test under tsan.
gunan Oct 7, 2017
5a107a9
Fix broken docs links to other TensorFlow interfaces in tf.contrib.le…
nealwu Oct 7, 2017
394e560
Add a custom estimator example to the regression cookbook.
MarkDaoust Oct 7, 2017
f8f1cce
Log in executor when a synchronous node is finished.
reedwm Oct 7, 2017
8433946
Make name scopes consistent.
tensorflower-gardener Oct 7, 2017
54b8c7b
Mirror SQLite zip file
jart Oct 8, 2017
3431602
Disable kmeans test in tsan.
gunan Oct 8, 2017
cab4f6f
Improve invalid size vocab ValueError by appending the vocab file.
tensorflower-gardener Oct 8, 2017
e0924e0
[TFXLA] Don't discard status unless it is NotFound.
jpienaar Oct 8, 2017
21da236
Disable flaky cluster_function_library_runtime_test in opensource.
gunan Oct 9, 2017
bb789ad
[TF:XLA] Rename HloOpcode::kLogicalX to kX
tensorflower-gardener Oct 9, 2017
edfb9bb
Correct documentation typo.
asimshankar Oct 9, 2017
b0b92fd
[tf.data] Add new custom transformation: `tf.contrib.data.scan()`.
mrry Oct 9, 2017
4878a28
Update ops-related pbtxt files.
tensorflower-gardener Oct 9, 2017
7e2b50d
Update docs of MomentumOptimizer about use_nesterov and of RMSProp about
tensorflower-gardener Oct 9, 2017
5bba158
Print numpy value for variables when in Eager mode
akshayka Oct 9, 2017
ff80191
Estimator.predict should not generate warning if user uses TF dataset.
Oct 9, 2017
15dd5fd
Track persistent memory in constant op.
Oct 9, 2017
e56628b
[TF:XLA] Rename ComputationBuilder::LogicalX to X
tensorflower-gardener Oct 9, 2017
4a97a82
Validate input shapes for the graph_callable decorator
akshayka Oct 9, 2017
8ed8e22
Make ops_test.py work with the C API enabled.
skye Oct 9, 2017
1ba562a
Rewrote the clip_by_norm op to avoid generating infinite intermediate…
benoitsteiner Oct 9, 2017
27df639
[Grappler] Correctly replace control-dependency uses.
Oct 9, 2017
11c123b
[TF:XLA] Rename HLO visitor methods from LogicalX to X
tensorflower-gardener Oct 9, 2017
0ac688a
Adding a binary classification example
tensorflower-gardener Oct 9, 2017
7e4e336
Relanding change to add config to enable S3 file system support.
Oct 9, 2017
7c74d2f
Expose tfe.test, tfe.in_eager_mode, tfe.in_graph_mode
alextp Oct 9, 2017
be69f13
[TF:XLA] Fix broken build of xla_interpreter_device.
hawkinsp Oct 9, 2017
33d5512
[Grappler] Fixed two bugs in ArithmeticOptimizer.
Oct 9, 2017
8814502
Removing side outputs from tape code.
alextp Oct 9, 2017
f49f6cd
Replace CHECK() with a WARNING in StepStatsCollector so that Save aft…
tensorflower-gardener Oct 9, 2017
0cbd8c7
New CUDA kernel for LSTMBlockCell's forward propagation.
tensorflower-gardener Oct 10, 2017
319d823
TFE: Fix reference counts when copying to Numpy arrays.
allenlavoie Oct 10, 2017
3a52d39
New CUDA kernel for LSTMBlockCell's forward propagation.
tensorflower-gardener Oct 10, 2017
fdb2b12
TFE: Fix reference counts when copying to Numpy arrays.
allenlavoie Oct 10, 2017
8ff5070
[Grappler] Optimize bitcasts.
Oct 10, 2017
319a359
Create a cuda9 cudnn 7 docker file, simpler, using ARGS.
gunan Oct 10, 2017
52d3a84
Fix wasserstein gradient penalty name scope issue and add the proper …
tensorflower-gardener Oct 10, 2017
485cb17
Fix the example in the RNN tutorial which left out one of the pieces …
nealwu Oct 10, 2017
07d78dd
Removes the use of tf.cond in the SweepHook used in the WALSMatrixFac…
tensorflower-gardener Oct 10, 2017
2cdd064
Make error message more explicit when running FusedConv2DBiasActivati…
tensorflower-gardener Oct 10, 2017
cd37dbb
Benchmark for LSTMBlockCell's forward propagation.
tensorflower-gardener Oct 10, 2017
103d383
Add scaled_softplus to the documented symbols so it can be accessed a…
tensorflower-gardener Oct 10, 2017
d08cb10
Scheduler exports tensor size info to RunMetadata. In addition, tenso…
tensorflower-gardener Oct 10, 2017
403e510
[XLA] Factor out repeated LatestNonGteAncestorAndIndex helper.
Oct 10, 2017
84f1b90
[XLA:LLVM] Rename ops.h to tuple_ops.h.
Oct 10, 2017
d98519b
[XLA:CPU] Let the elementwise concat op handle being emitted into a d…
Oct 10, 2017
4f102ff
Cache last zero tensor in eager gradient computation
iganichev Oct 10, 2017
effb22e
Use an external constant pool to reduce LLVM compile times
Oct 10, 2017
1be36dd
[TF:XLA] Re-enable strided slice tests that now pass.
hawkinsp Oct 10, 2017
90f257e
Fix ReshapeMover bug with reshaped constants; add HloVerifiedTestBase.
tensorflower-gardener Oct 10, 2017
5a26d1e
Minor cleanup (remove unused inclusions, NULL => nullptr)
tensorflower-gardener Oct 10, 2017
07bf1d3
Merge commit for internal changes
caisq Oct 10, 2017
Fix wasserstein gradient penalty name scope issue and add the proper name scope.

PiperOrigin-RevId: 171610946

tensorflower-gardener committed Oct 10, 2017
commit 52d3a842463d11990600bb65f9752b59f6d8f418
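The fix in this commit follows a recurring TF 1.x pattern: wrap the function body in its own `ops.name_scope`, but clear the name scope again (via `ops.name_scope(None)`) before calling back into user code, so ops the callback adds to collections keep names under the discriminator's scope. A minimal standalone sketch of that pattern; the helper name `run_in_clean_scope` and the scope name `dscope` are illustrative, not from the commit:

```python
import tensorflow as tf

def run_in_clean_scope(fn, *args):
  """Hypothetical helper mirroring the pattern applied in this commit."""
  # `tf.name_scope(None)` resets to the graph's root name scope, so ops
  # created by `fn` are named 'dscope/...' rather than '<caller_scope>/...'
  # and remain discoverable via tf.get_collection(key, scope='dscope').
  with tf.name_scope(None):
    with tf.variable_scope('dscope', reuse=tf.AUTO_REUSE):
      return fn(*args)
```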
83 changes: 42 additions & 41 deletions tensorflow/contrib/gan/python/losses/python/losses_impl.py
```diff
@@ -297,7 +297,6 @@ def acgan_generator_loss(
 # GANs` (https://arxiv.org/abs/1704.00028).
 
 
-# TODO(joelshor): Figure out why this function can't be inside a name scope.
 def wasserstein_gradient_penalty(
     real_data,
     generated_data,
@@ -339,48 +338,50 @@ def wasserstein_gradient_penalty(
   Raises:
     ValueError: If the rank of data Tensors is unknown.
   """
-  real_data = ops.convert_to_tensor(real_data)
-  generated_data = ops.convert_to_tensor(generated_data)
-  if real_data.shape.ndims is None:
-    raise ValueError('`real_data` can\'t have unknown rank.')
-  if generated_data.shape.ndims is None:
-    raise ValueError('`generated_data` can\'t have unknown rank.')
-
-  differences = generated_data - real_data
-  batch_size = differences.shape[0].value or array_ops.shape(differences)[0]
-  alpha_shape = [batch_size] + [1] * (differences.shape.ndims - 1)
-  alpha = random_ops.random_uniform(shape=alpha_shape)
-  interpolates = real_data + (alpha * differences)
-
-  # Reuse variables if a discriminator scope already exists.
-  reuse = False if discriminator_scope is None else True
-  with variable_scope.variable_scope(discriminator_scope, 'gpenalty_dscope',
-                                     reuse=reuse):
-    disc_interpolates = discriminator_fn(interpolates, generator_inputs)
-
-  if isinstance(disc_interpolates, tuple):
-    # ACGAN case: disc outputs more than one tensor
-    disc_interpolates = disc_interpolates[0]
-
-  gradients = gradients_impl.gradients(disc_interpolates, interpolates)[0]
-  gradient_squares = math_ops.reduce_sum(
-      math_ops.square(gradients), axis=list(range(1, gradients.shape.ndims)))
-  # Propagate shape information, if possible.
-  if isinstance(batch_size, int):
-    gradient_squares.set_shape([
-        batch_size] + gradient_squares.shape.as_list()[1:])
-  # For numerical stability, add epsilon to the sum before taking the square
-  # root. Note tf.norm does not add epsilon.
-  slopes = math_ops.sqrt(gradient_squares + epsilon)
-  penalties = math_ops.square(slopes - 1.0)
-  penalty = losses.compute_weighted_loss(
-      penalties, weights, scope=scope, loss_collection=loss_collection,
-      reduction=reduction)
+  with ops.name_scope(scope, 'wasserstein_gradient_penalty',
+                      (real_data, generated_data)) as scope:
+    real_data = ops.convert_to_tensor(real_data)
+    generated_data = ops.convert_to_tensor(generated_data)
+    if real_data.shape.ndims is None:
+      raise ValueError('`real_data` can\'t have unknown rank.')
+    if generated_data.shape.ndims is None:
+      raise ValueError('`generated_data` can\'t have unknown rank.')
+
+    differences = generated_data - real_data
+    batch_size = differences.shape[0].value or array_ops.shape(differences)[0]
+    alpha_shape = [batch_size] + [1] * (differences.shape.ndims - 1)
+    alpha = random_ops.random_uniform(shape=alpha_shape)
+    interpolates = real_data + (alpha * differences)
+
+    with ops.name_scope(None):  # Clear scope so update ops are added properly.
+      # Reuse variables if variables already exists.
+      with variable_scope.variable_scope(discriminator_scope, 'gpenalty_dscope',
+                                         reuse=variable_scope.AUTO_REUSE):
+        disc_interpolates = discriminator_fn(interpolates, generator_inputs)
+
+    if isinstance(disc_interpolates, tuple):
+      # ACGAN case: disc outputs more than one tensor
+      disc_interpolates = disc_interpolates[0]
+
+    gradients = gradients_impl.gradients(disc_interpolates, interpolates)[0]
+    gradient_squares = math_ops.reduce_sum(
+        math_ops.square(gradients), axis=list(range(1, gradients.shape.ndims)))
+    # Propagate shape information, if possible.
+    if isinstance(batch_size, int):
+      gradient_squares.set_shape([
+          batch_size] + gradient_squares.shape.as_list()[1:])
+    # For numerical stability, add epsilon to the sum before taking the square
+    # root. Note tf.norm does not add epsilon.
+    slopes = math_ops.sqrt(gradient_squares + epsilon)
+    penalties = math_ops.square(slopes - 1.0)
+    penalty = losses.compute_weighted_loss(
+        penalties, weights, scope=scope, loss_collection=loss_collection,
+        reduction=reduction)
 
-  if add_summaries:
-    summary.scalar('gradient_penalty_loss', penalty)
+    if add_summaries:
+      summary.scalar('gradient_penalty_loss', penalty)
 
-  return penalty
+    return penalty
 
 
 # Original losses from `Generative Adversarial Nets`
```
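For orientation: the function computes the WGAN-GP penalty, roughly E[(||grad D(x̂)||₂ - 1)²], from the paper cited in the context above (https://arxiv.org/abs/1704.00028). A minimal usage sketch, modeled loosely on this PR's test setup; the toy discriminator and tensor shapes are illustrative, not part of the change:

```python
import tensorflow as tf
from tensorflow.contrib.gan.python.losses.python import losses_impl as tfgan_losses

def discriminator_fn(inputs, _):
  # Toy discriminator: a single trainable scalar times the inputs.
  return tf.get_variable('dummy_d', initializer=2.0) * inputs

real_data = tf.ones([1, 2])
generated_data = tf.zeros([1, 2])

# Run the discriminator once so its variables exist, capturing its scope.
with tf.variable_scope('dscope') as dis_scope:
  discriminator_fn(real_data, None)

# After this commit the loss ops land under their own
# 'wasserstein_gradient_penalty/...' name scope, even when called
# inside an enclosing scope such as 'loss/'.
with tf.name_scope('loss'):
  penalty = tfgan_losses.wasserstein_gradient_penalty(
      real_data, generated_data,
      generator_inputs=None,  # unused by this toy discriminator
      discriminator_fn=discriminator_fn,
      discriminator_scope=dis_scope)
```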
23 changes: 22 additions & 1 deletion tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
```diff
@@ -453,10 +453,11 @@ def setUp(self):
         'discriminator_scope': self._scope,
     }
     self._expected_loss = 9.00000
-    self._expected_op_name = 'weighted_loss/value'
+    self._expected_op_name = 'wasserstein_gradient_penalty/value'
     self._batch_size = 1
 
   def _discriminator_fn(self, inputs, _):
+    ops.add_to_collection('fake_update_ops', constant_op.constant(1.0))
     return variable_scope.get_variable('dummy_d', initializer=2.0) * inputs
 
   def test_loss_with_placeholder(self):
@@ -487,6 +488,26 @@ def test_reuses_scope(self):
     self.assertEqual(
         num_vars, len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
 
+  def test_works_with_get_collection(self):
+    """Tests that gradient penalty works inside other scopes."""
+    # We ran the discriminator once in the setup, so there should be an op
+    # already in the collection.
+    self.assertEqual(1, len(ops.get_collection(
+        'fake_update_ops', self._kwargs['discriminator_scope'].name)))
+
+    # Make sure the op is added to the collection even if it's in a name scope.
+    with ops.name_scope('loss'):
+      tfgan_losses.wasserstein_gradient_penalty(**self._kwargs)
+    self.assertEqual(2, len(ops.get_collection(
+        'fake_update_ops', self._kwargs['discriminator_scope'].name)))
+
+    # Make sure the op is added to the collection even if it's in a variable
+    # scope.
+    with variable_scope.variable_scope('loss_vscope'):
+      tfgan_losses.wasserstein_gradient_penalty(**self._kwargs)
+    self.assertEqual(3, len(ops.get_collection(
+        'fake_update_ops', self._kwargs['discriminator_scope'].name)))
+
 
 class MutualInformationPenaltyTest(test.TestCase, _PenaltyTest):
   """Tests for mutual_information_penalty."""
```
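The new test relies on the `scope` argument of `tf.get_collection`, which filters collection items by matching their names against the scope. Because the fix clears the enclosing name scope before invoking the discriminator, the ops it adds stay under the discriminator's scope and are found; without the fix they would be prefixed with 'loss/' and missed. A standalone sketch of that filtering behavior (the collection key and scope names are illustrative):

```python
import tensorflow as tf

with tf.Graph().as_default():
  with tf.name_scope('dscope'):
    tf.add_to_collection('fake_update_ops', tf.constant(1.0))
  with tf.name_scope('loss'):
    # Named 'loss/Const:0', so a scope='dscope' filter will not return it.
    tf.add_to_collection('fake_update_ops', tf.constant(2.0))

  assert len(tf.get_collection('fake_update_ops', scope='dscope')) == 1
  assert len(tf.get_collection('fake_update_ops')) == 2
```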