[go: nahoru, domu]

Skip to content

Commit

Permalink
Penalty mixture coefficient parameter in the ranking agent.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 462363703
Change-Id: If192bdf263dd9802732838dc83802d7e2ac29546
  • Loading branch information
bartokg authored and Copybara-Service committed Jul 21, 2022
1 parent 5c38871 commit d74a620
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions tf_agents/bandits/agents/ranking_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def __init__(
summarize_grads_and_vars: bool = False,
enable_summaries: bool = True,
train_step_counter: Optional[tf.Variable] = None,
penalty_mixture_coefficient: float = 1.,
name: Optional[Text] = None):
"""Initializes an instance of RankingAgent.
Expand Down Expand Up @@ -163,6 +164,9 @@ def __init__(
(debug or otherwise) should not be written.
train_step_counter: An optional `tf.Variable` to increment every time the
train op is run. Defaults to the `global_step`.
penalty_mixture_coefficient: A parameter responsible for the balance
between selecting high-scoring items and enforcing diversity. Used only
by diversity-based policies.
name: The name of this agent instance.
"""
tf.Module.__init__(self, name=name)
Expand Down Expand Up @@ -201,6 +205,7 @@ def __init__(
self._num_slots,
time_step_spec,
scoring_network,
penalty_mixture_coefficient=penalty_mixture_coefficient,
logits_temperature=logits_temperature)
elif policy_type == RankingPolicyType.NO_PENALTY:
policy = ranking_policy.NoPenaltyRankingPolicy(
Expand Down

0 comments on commit d74a620

Please sign in to comment.