[go: nahoru, domu]

Skip to content

Commit

Permalink
Fixes the order in which the ranking agent passes labels and predictions to the loss function.
Browse files Browse the repository at this point in the history

TF loss functions follow the convention that labels are passed as the first argument and predictions as the second. Most loss functions are symmetric in these two arguments, so the order does not matter, but some are not: for example, sigmoid_cross_entropy treats its second argument as logits, so swapping the arguments changes the computed loss.

PiperOrigin-RevId: 568945812
Change-Id: I079ba05270c338339bc68d10b6485f01c9c3b85d
  • Loading branch information
TF-Agents Team authored and Copybara-Service committed Sep 27, 2023
1 parent 97d8a57 commit 498923f
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 14 deletions.
2 changes: 1 addition & 1 deletion tf_agents/bandits/agents/ranking_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def _loss(

est_reward = self._scoring_network(flat_obs, training)[0]
loss_output = self._error_loss_fn(
est_reward, score, reduction=tf.compat.v1.losses.Reduction.NONE
score, est_reward, reduction=tf.compat.v1.losses.Reduction.NONE
)
if len(list(loss_output.shape)) == 1:
# In case the loss is an aggregate over all slots, we only use one weight
Expand Down
32 changes: 19 additions & 13 deletions tf_agents/bandits/agents/ranking_agent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def testTrainAgentCascadingFeedback(
'num_items': 13,
'num_slots': 11,
'non_click_score': 0,
'loss': 'softmax_cross_entropy',
'loss': 'sigmoid_cross_entropy',
},
])
def testTrainAgentScoreFeedback(
Expand All @@ -215,7 +215,7 @@ def testTrainAgentScoreFeedback(
if not tf.executing_eagerly():
self.skipTest('Only works in eager mode.')
obs_spec = bandit_spec_utils.create_per_arm_observation_spec(
global_dim, item_dim, num_items
global_dim, item_dim, num_items, add_num_actions_feature=True
)
scoring_net = (
global_and_arm_feature_network.create_feed_forward_common_tower_network(
Expand All @@ -230,7 +230,7 @@ def testTrainAgentScoreFeedback(
)
if non_click_score is not None:
with self.assertRaisesRegex(ValueError, 'Parameter `non_click_score`'):
agent = ranking_agent.RankingAgent(
_ = ranking_agent.RankingAgent(
time_step_spec=time_step_spec,
action_spec=action_spec,
scoring_network=scoring_net,
Expand All @@ -241,20 +241,14 @@ def testTrainAgentScoreFeedback(
)
non_click_score = None

def loss_fn(logits, labels, reduction):
del reduction
return tf.nn.softmax_cross_entropy_with_logits(
labels=labels, logits=logits
)

agent = ranking_agent.RankingAgent(
time_step_spec=time_step_spec,
action_spec=action_spec,
scoring_network=scoring_net,
policy_type=policy_type,
error_loss_fn=(
loss_fn
if loss == 'softmax_cross_entropy'
tf.compat.v1.losses.sigmoid_cross_entropy
if loss == 'sigmoid_cross_entropy'
else tf.compat.v1.losses.mean_squared_error
),
feedback_model=ranking_agent.FeedbackModel.SCORE_VECTOR,
Expand All @@ -269,14 +263,24 @@ def loss_fn(logits, labels, reduction):
tf.range(batch_size * num_slots * item_dim, dtype=tf.float32),
[batch_size, num_slots, item_dim],
)
num_actions = tf.constant(num_slots - 1, shape=[batch_size], dtype=tf.int32)

observations = {
bandit_spec_utils.GLOBAL_FEATURE_KEY: global_obs,
bandit_spec_utils.PER_ARM_FEATURE_KEY: item_obs,
bandit_spec_utils.NUM_ACTIONS_FEATURE_KEY: num_actions,
}
scores = tf.reshape(
tf.range(batch_size * num_slots, dtype=tf.float32),
shape=[batch_size, num_slots],
)
if loss == 'sigmoid_cross_entropy':
scores = tf.where(
tf.greater(scores, tf.reduce_mean(scores)),
tf.ones_like(scores, dtype=tf.float32),
tf.zeros_like(scores, dtype=tf.float32),
)

initial_step, final_step = _get_initial_and_final_steps(
observations, scores
)
Expand All @@ -290,8 +294,10 @@ def loss_fn(logits, labels, reduction):
),
)
experience = _get_experience(initial_step, action_step, final_step)
weights = tf.range(batch_size, dtype=tf.float32)
agent.train(experience, weights)
for i in range(10):
self.assertGreaterEqual(
agent.train(experience).loss, 0, msg=f'Train step {i}'
)

@parameterized.parameters([
{
Expand Down

0 comments on commit 498923f

Please sign in to comment.