[go: nahoru, domu]

Skip to content

Commit

Permalink
Add FIXED_BIAS_WEIGHTS positional bias type to ranking_agent
Browse files Browse the repository at this point in the history
`FIXED_BIAS_WEIGHTS` is a new positional bias type in which a fixed bias weight for every slot is given as an array. These weights can come from the user's prior knowledge, from offline analysis (e.g. TopN randomization), and so on.

PiperOrigin-RevId: 617544784
Change-Id: I3cd891bffdc65626d3e0ba513f050827f74ce62a
  • Loading branch information
TF-Agents Team authored and Copybara-Service committed Mar 20, 2024
1 parent c2fcf32 commit c846013
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 22 deletions.
13 changes: 12 additions & 1 deletion tf_agents/bandits/agents/examples/v2/train_eval_ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,19 @@
'bias_type',
'',
'Whether the agent models the positional '
'bias with the basis or the exponent changes. If unset, the'
'bias with the basis, the exponent or fixed bias weights. If unset, the'
' agent applies no positional bias.',
)
flags.DEFINE_float(
'bias_severity', 1.0, 'The severity of the bias adjustment by the agent.'
)
flags.DEFINE_list(
'bias_weights',
[],
'The positional bias weights. For FIXED_BIAS_WEIGHTS type, the agent will'
' use these weights to adjust the rewards. The length of the list must be'
' equal to the number of slots.',
)
flags.DEFINE_bool(
'bias_positive_only',
False,
Expand Down Expand Up @@ -174,12 +181,15 @@ def _relevance_fn(global_obs, item_obs):
positional_bias_type = ranking_agent.PositionalBiasType.BASE
elif FLAGS.positional_bias_type == 'exponent':
positional_bias_type = ranking_agent.PositionalBiasType.EXPONENT
elif FLAGS.positional_bias_type == 'fixed_bias_weights':
positional_bias_type = ranking_agent.PositionalBiasType.FIXED_BIAS_WEIGHTS
else:
raise NotImplementedError(
'Positional bias type {} is not implemented'.format(
FLAGS.positional_bias_type
)
)
positional_bias_weights = [float(w) for w in FLAGS.positional_bias_weights]

agent = ranking_agent.RankingAgent(
time_step_spec=environment.time_step_spec(),
Expand All @@ -190,6 +200,7 @@ def _relevance_fn(global_obs, item_obs):
feedback_model=feedback_model,
positional_bias_type=positional_bias_type,
positional_bias_severity=FLAGS.bias_severity,
positional_bias_weights=positional_bias_weights,
positional_bias_positive_only=FLAGS.bias_positive_only,
summarize_grads_and_vars=True,
)
Expand Down
61 changes: 46 additions & 15 deletions tf_agents/bandits/agents/ranking_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@
recommendation. The user is responsible for converting the observation to the
syntax required by the agent.
"""

import enum
from typing import Optional, Text
from typing import List, Optional, Text

import tensorflow as tf # pylint: disable=g-explicit-tensorflow-version-import
from tf_agents.agents import tf_agent
Expand Down Expand Up @@ -127,6 +128,9 @@ class PositionalBiasType(enum.Enum):
# et al. `Correcting for Selection Bias in Learning-to-rank Systems`
# (WWW 2020).
EXPONENT = 2
# The bias weight for each slot position is `bias_weights[k]`, where
# `bias_weights` is the given bias weight array and `k` is the position.
FIXED_BIAS_WEIGHTS = 3


class RankingAgent(tf_agent.TFAgent):
Expand All @@ -144,6 +148,7 @@ def __init__(
non_click_score: Optional[float] = None,
positional_bias_type: PositionalBiasType = PositionalBiasType.UNSET,
positional_bias_severity: Optional[float] = None,
positional_bias_weights: Optional[List[float]] = None,
positional_bias_positive_only: bool = False,
logits_temperature: float = 1.0,
summarize_grads_and_vars: bool = False,
Expand Down Expand Up @@ -178,6 +183,8 @@ def __init__(
positional_bias_type: Type of positional bias to use when training.
positional_bias_severity: (float) The severity `s`, used for the `BASE`
positional bias type.
positional_bias_weights: (float array) The positional bias weight for each
slot position.
positional_bias_positive_only: Whether to use the above defined bias
weights only for positives (that is, clicked items). If
`positional_bias_type` is unset, this parameter has no effect.
Expand Down Expand Up @@ -230,6 +237,22 @@ def __init__(
)
self._positional_bias_type = positional_bias_type
self._positional_bias_severity = positional_bias_severity
# Validate positional_bias_weights for FIXED_BIAS_WEIGHTS PositionalBiasType
if self._positional_bias_type == PositionalBiasType.FIXED_BIAS_WEIGHTS:
if positional_bias_weights is None:
raise ValueError(
'positional_bias_weights is None but should never be for'
' FIXED_BIAS_WEIGHTS PositionalBiasType.'
)
elif len(positional_bias_weights) != self._num_slots:
raise ValueError(
'The length of positional_bias_weights should be the same as the'
' number of slots. The length of positional_bias_weights is {} and'
' the number of slots is {}.'.format(
len(positional_bias_weights), self._num_slots
)
)
self._positional_bias_weights = positional_bias_weights
self._positional_bias_positive_only = positional_bias_positive_only
if policy_type == RankingPolicyType.UNKNOWN:
policy_type = RankingPolicyType.COSINE_DISTANCE
Expand Down Expand Up @@ -409,19 +432,27 @@ def _construct_sample_weights(self, reward, observation, weights):
chosen_index + 1, self._num_slots, dtype=tf.float32
)
weights = multiplier * weights
if self._positional_bias_type != PositionalBiasType.UNSET:
batched_range = tf.broadcast_to(
tf.range(self._num_slots, dtype=tf.float32), tf.shape(weights)

if self._positional_bias_type == PositionalBiasType.UNSET:
return weights

batched_range = tf.broadcast_to(
tf.range(self._num_slots, dtype=tf.float32), tf.shape(weights)
)
if self._positional_bias_type == PositionalBiasType.BASE:
position_bias_multipliers = tf.pow(
batched_range + 1, self._positional_bias_severity
)
if self._positional_bias_type == PositionalBiasType.BASE:
position_bias_multipliers = tf.pow(
batched_range + 1, self._positional_bias_severity
)
elif self._positional_bias_type == PositionalBiasType.EXPONENT:
position_bias_multipliers = tf.pow(
self._positional_bias_severity, batched_range
)
else:
raise ValueError('non-existing positional bias type')
weights = position_bias_multipliers * weights
elif self._positional_bias_type == PositionalBiasType.EXPONENT:
position_bias_multipliers = tf.pow(
self._positional_bias_severity, batched_range
)
elif self._positional_bias_type == PositionalBiasType.FIXED_BIAS_WEIGHTS:
position_bias_multipliers = tf.tile(
tf.expand_dims(self._positional_bias_weights, axis=0),
[batch_size, 1],
)
else:
raise ValueError('non-existing positional bias type')
weights = position_bias_multipliers * weights
return weights
44 changes: 38 additions & 6 deletions tf_agents/bandits/agents/ranking_agent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,8 @@ def testTrainAgentScoreFeedback(
'positional_bias_type': ranking_agent.PositionalBiasType.BASE,
'positional_bias_severity': 1.2,
'positional_bias_positive_only': False,
'positional_bias_weights': None,
'expected_second_weight': 2.2974, # 2**positional_bias_severity
},
{
'feedback_model': ranking_agent.FeedbackModel.SCORE_VECTOR,
Expand All @@ -323,6 +325,8 @@ def testTrainAgentScoreFeedback(
'positional_bias_type': ranking_agent.PositionalBiasType.EXPONENT,
'positional_bias_severity': 1.3,
'positional_bias_positive_only': False,
'positional_bias_weights': None,
'expected_second_weight': 1.3, # positional_bias_severity
},
{
'feedback_model': ranking_agent.FeedbackModel.SCORE_VECTOR,
Expand All @@ -335,6 +339,36 @@ def testTrainAgentScoreFeedback(
'positional_bias_type': ranking_agent.PositionalBiasType.BASE,
'positional_bias_severity': 1.0,
'positional_bias_positive_only': True,
'positional_bias_weights': None,
'expected_second_weight': 2.0, # 2**positional_bias_severity
},
{
'feedback_model': ranking_agent.FeedbackModel.SCORE_VECTOR,
'policy_type': ranking_agent.RankingPolicyType.DESCENDING_SCORES,
'batch_size': 2,
'global_dim': 3,
'item_dim': 4,
'num_items': 13,
'num_slots': 11,
'positional_bias_type': (
ranking_agent.PositionalBiasType.FIXED_BIAS_WEIGHTS
),
'positional_bias_severity': None,
'positional_bias_positive_only': True,
'positional_bias_weights': [
0.1,
0.2,
0.3,
0.4,
0.5,
0.6,
0.7,
0.8,
0.9,
1.0,
1.1,
],
'expected_second_weight': 0.2, # positional_bias_weights[1]
},
])
def testPositionalBiasParams(
Expand All @@ -349,6 +383,8 @@ def testPositionalBiasParams(
positional_bias_type,
positional_bias_severity,
positional_bias_positive_only,
positional_bias_weights,
expected_second_weight,
):
if not tf.executing_eagerly():
self.skipTest('Only works in eager mode.')
Expand Down Expand Up @@ -386,6 +422,7 @@ def testPositionalBiasParams(
positional_bias_type=positional_bias_type,
positional_bias_severity=positional_bias_severity,
positional_bias_positive_only=positional_bias_positive_only,
positional_bias_weights=positional_bias_weights,
optimizer=optimizer,
)
global_obs = tf.reshape(
Expand Down Expand Up @@ -426,12 +463,7 @@ def testPositionalBiasParams(
agent.train(experience)
weights = agent._construct_sample_weights(scores, observations, None)
self.assertAllEqual(weights.shape, [batch_size, num_slots])
expected = (
2**positional_bias_severity
if positional_bias_type == ranking_agent.PositionalBiasType.BASE
else positional_bias_severity
)
self.assertAllClose(weights[-1, 1], expected)
self.assertAllClose(weights[-1, 1], expected_second_weight)


if __name__ == '__main__':
Expand Down

0 comments on commit c846013

Please sign in to comment.