enumerate positional bias types
PiperOrigin-RevId: 607099230
Change-Id: Ifa80cbc06f273c5a1dcafdebd5005fbce3ee21af
TF-Agents Team authored and Copybara-Service committed Feb 14, 2024
1 parent ccff2b3 commit c2fcf32
Showing 3 changed files with 44 additions and 46 deletions.
12 changes: 11 additions & 1 deletion tf_agents/bandits/agents/examples/v2/train_eval_ranking.py
@@ -170,7 +170,17 @@ def _relevance_fn(global_obs, item_obs):
raise NotImplementedError(
'Policy type {} is not implemented'.format(FLAGS.policy_type)
)
-positional_bias_type = FLAGS.bias_type or None
+if FLAGS.positional_bias_type == 'base':
+  positional_bias_type = ranking_agent.PositionalBiasType.BASE
+elif FLAGS.positional_bias_type == 'exponent':
+  positional_bias_type = ranking_agent.PositionalBiasType.EXPONENT
+else:
+  raise NotImplementedError(
+      'Positional bias type {} is not implemented'.format(
+          FLAGS.positional_bias_type
+      )
+  )
+
agent = ranking_agent.RankingAgent(
time_step_spec=environment.time_step_spec(),
action_spec=environment.action_spec(),
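A note on the dispatch pattern above: the same string-to-enum mapping can be expressed as a dictionary lookup. The sketch below is illustrative only and is not part of this commit; it declares a minimal stand-in for `ranking_agent.PositionalBiasType`, and the helper name `parse_positional_bias_type` is hypothetical.

import enum


class PositionalBiasType(enum.Enum):
  # Minimal stand-in for ranking_agent.PositionalBiasType, for illustration only.
  UNSET = 0
  BASE = 1
  EXPONENT = 2


_FLAG_TO_BIAS_TYPE = {
    'base': PositionalBiasType.BASE,
    'exponent': PositionalBiasType.EXPONENT,
}


def parse_positional_bias_type(flag_value):
  # Mirrors the if/elif chain above: unknown values raise NotImplementedError.
  try:
    return _FLAG_TO_BIAS_TYPE[flag_value]
  except KeyError:
    raise NotImplementedError(
        'Positional bias type {} is not implemented'.format(flag_value)
    )


assert parse_positional_bias_type('exponent') is PositionalBiasType.EXPONENT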
38 changes: 21 additions & 17 deletions tf_agents/bandits/agents/ranking_agent.py
@@ -116,6 +116,19 @@ class FeedbackModel(enum.Enum):
SCORE_VECTOR = 2


+class PositionalBiasType(enum.Enum):
+  """Enumeration of positional bias types."""
+
+  UNSET = 0
+  # The bias weight for each slot position is `k^s`, where `s` is the bias
+  # severity and `k` is the position.
+  BASE = 1
+  # The weights are `s^k`. These bias adjustment types are inspired by Ovaisi
+  # et al. `Correcting for Selection Bias in Learning-to-rank Systems`
+  # (WWW 2020).
+  EXPONENT = 2
+
+
class RankingAgent(tf_agent.TFAgent):
"""Ranking agent class."""

@@ -129,7 +142,7 @@ def __init__(
error_loss_fn: types.LossFn = tf.compat.v1.losses.mean_squared_error,
feedback_model: FeedbackModel = FeedbackModel.CASCADING,
non_click_score: Optional[float] = None,
-positional_bias_type: Optional[Text] = None,
+positional_bias_type: PositionalBiasType = PositionalBiasType.UNSET,
positional_bias_severity: Optional[float] = None,
positional_bias_positive_only: bool = False,
logits_temperature: float = 1.0,
@@ -162,16 +175,9 @@ def __init__(
non_click_score: (float) For the cascading feedback model, this is the
score value for items lying "before" the clicked item. If not set, -1 is
used. It is recommended (but not enforced) to use a negative value.
-positional_bias_type: (string) If not set (or set to `None`), the agent
-does not apply bias adjustment. If set to either `base` or `exponent`,
-it parameter determines what way the positional bias is accounted for.
-`base`: The bias weight for each slot position is `k^s`, where `s` is
-the bias severity (set in the next parameter), and `k` is the position.
-`exponent`: The weights are `s^k`. These bias adjustment types are
-inspired by Ovaisi et al. `Correcting for Selection Bias in
-Learning-to-rank Systems` (WWW 2020).
-positional_bias_severity: (float) The severity `s`, used as explained
-above. If `positional_bias_type` is unset, this parameter has no effect.
+positional_bias_type: Type of positional bias to use when training.
+positional_bias_severity: (float) The severity `s`, used in the bias weights
+defined by `PositionalBiasType`; has no effect if the bias type is `UNSET`.
positional_bias_positive_only: Whether to use the above defined bias
weights only for positives (that is, clicked items). If
`positional_bias_type` is unset, this parameter has no effect.
@@ -403,21 +409,19 @@ def _construct_sample_weights(self, reward, observation, weights):
chosen_index + 1, self._num_slots, dtype=tf.float32
)
weights = multiplier * weights
-      if self._positional_bias_type is not None:
+      if self._positional_bias_type != PositionalBiasType.UNSET:
batched_range = tf.broadcast_to(
tf.range(self._num_slots, dtype=tf.float32), tf.shape(weights)
)
-        if self._positional_bias_type == 'base':
+        if self._positional_bias_type == PositionalBiasType.BASE:
position_bias_multipliers = tf.pow(
batched_range + 1, self._positional_bias_severity
)
-        elif self._positional_bias_type == 'exponent':
+        elif self._positional_bias_type == PositionalBiasType.EXPONENT:
position_bias_multipliers = tf.pow(
self._positional_bias_severity, batched_range
)
else:
-          raise ValueError(
-            'non-existing bias type: ' + self._positional_bias_type
-          )
+          raise ValueError('non-existing positional bias type')
weights = position_bias_multipliers * weights
return weights
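To make the two weighting schemes concrete, here is a small NumPy sketch, not part of the commit, that mirrors the multiplier math in `_construct_sample_weights` above. The position index `k` is 0-based, exactly as produced by `tf.range(self._num_slots)`; the function name and the use of NumPy instead of TensorFlow are illustrative choices.

import numpy as np


def positional_bias_multipliers(bias_type, severity, num_slots):
  # BASE uses (k + 1)**s, EXPONENT uses s**k, with k = 0, 1, ..., num_slots - 1
  # and s the bias severity, matching the tf.pow calls above.
  positions = np.arange(num_slots, dtype=np.float32)
  if bias_type == 'base':
    return np.power(positions + 1.0, severity)
  if bias_type == 'exponent':
    return np.power(severity, positions)
  raise ValueError('non-existing positional bias type')


print(positional_bias_multipliers('base', 1.2, 5))      # approx. [1. 2.3 3.74 5.28 6.9]
print(positional_bias_multipliers('exponent', 1.3, 5))  # approx. [1. 1.3 1.69 2.2 2.86]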
40 changes: 12 additions & 28 deletions tf_agents/bandits/agents/ranking_agent_test.py
@@ -308,7 +308,7 @@ def testTrainAgentScoreFeedback(
'item_dim': 3,
'num_items': 10,
'num_slots': 5,
-'positional_bias_type': 'base',
+'positional_bias_type': ranking_agent.PositionalBiasType.BASE,
'positional_bias_severity': 1.2,
'positional_bias_positive_only': False,
},
@@ -320,7 +320,7 @@ def testTrainAgentScoreFeedback(
'item_dim': 5,
'num_items': 21,
'num_slots': 17,
-'positional_bias_type': 'exponent',
+'positional_bias_type': ranking_agent.PositionalBiasType.EXPONENT,
'positional_bias_severity': 1.3,
'positional_bias_positive_only': False,
},
@@ -332,19 +332,7 @@ def testTrainAgentScoreFeedback(
'item_dim': 4,
'num_items': 13,
'num_slots': 11,
-'positional_bias_type': 'base',
-'positional_bias_severity': 1.0,
-'positional_bias_positive_only': True,
-},
-{
-'feedback_model': ranking_agent.FeedbackModel.SCORE_VECTOR,
-'policy_type': ranking_agent.RankingPolicyType.DESCENDING_SCORES,
-'batch_size': 2,
-'global_dim': 3,
-'item_dim': 4,
-'num_items': 13,
-'num_slots': 11,
-'positional_bias_type': 'invalid',
+'positional_bias_type': ranking_agent.PositionalBiasType.BASE,
'positional_bias_severity': 1.0,
'positional_bias_positive_only': True,
},
@@ -435,19 +423,15 @@ def testPositionalBiasParams(
),
)
experience = _get_experience(initial_step, action_step, final_step)
-if positional_bias_type == 'invalid':
-  with self.assertRaisesRegex(ValueError, 'non-existing bias type'):
-    agent.train(experience)
-else:
-  agent.train(experience)
-  weights = agent._construct_sample_weights(scores, observations, None)
-  self.assertAllEqual(weights.shape, [batch_size, num_slots])
-  expected = (
-      2**positional_bias_severity
-      if positional_bias_type == 'base'
-      else positional_bias_severity
-  )
-  self.assertAllClose(weights[-1, 1], expected)
+agent.train(experience)
+weights = agent._construct_sample_weights(scores, observations, None)
+self.assertAllEqual(weights.shape, [batch_size, num_slots])
+expected = (
+    2**positional_bias_severity
+    if positional_bias_type == ranking_agent.PositionalBiasType.BASE
+    else positional_bias_severity
+)
+self.assertAllClose(weights[-1, 1], expected)


if __name__ == '__main__':
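As a quick sanity check, not part of the commit, the expected values asserted above for slot index 1 are just the positional multipliers at a 0-based index of k = 1: BASE gives (k + 1)**s = 2**severity, while EXPONENT gives s**k = severity.

severity = 1.2
print((1 + 1) ** severity)  # BASE multiplier at index 1: 2**1.2, about 2.2974
print(severity ** 1)        # EXPONENT multiplier at index 1: the severity itself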
