Commit 868f276

Minor optimizations for ranking agents.
PiperOrigin-RevId: 482127141
Change-Id: I2c87cbf1b5a51631def91868f80cbdc356b75ad4
bartokg authored and Copybara-Service committed Oct 19, 2022
1 parent 3a4c0f2 commit 868f276
Showing 2 changed files with 71 additions and 25 deletions.
58 changes: 38 additions & 20 deletions tf_agents/bandits/policies/ranking_policy.py
@@ -40,45 +40,46 @@ class PenalizedPlackettLuce(tfd.PlackettLuce):
   def __init__(self,
                features: types.Tensor,
                num_slots: int,
-               scores: types.Tensor,
+               logits: types.Tensor,
                penalty_mixture_coefficient: float = 1.):
     """Initializes an instance of PenalizedPlackettLuce.

     Args:
       features: Item features based on which similarity is calculated.
       num_slots: The number of slots to fill: this many items will be sampled.
-      scores: Scores for the PlackettLuce distribution. Shape is `[num_items]`.
+      logits: Unnormalized log probabilities for the PlackettLuce distribution.
+        Shape is `[num_items]`.
       penalty_mixture_coefficient: A parameter responsible for the balance
         between selecting high scoring items and enforcing diversity.
     """
     self._features = features
     self._num_slots = num_slots
     self._penalty_mixture_coefficient = penalty_mixture_coefficient
-    super(PenalizedPlackettLuce, self).__init__(scores)
+    super(PenalizedPlackettLuce, self).__init__(scores=logits)

-  def _penalizer_fn(self, scores: types.Float, features: types.Float,
+  def _penalizer_fn(self, logits: types.Float, features: types.Float,
                     slots: Sequence[types.Int]):
     """Downscores items by their similarity to already selected items.

     Args:
-      scores: The current scores of all items.
+      logits: The current logits of all items.
       features: the feature vectors of the items.
       slots: list of indices of already selected items.

     Returns:
-      New scores.
+      New logits.
     """
     raise NotImplementedError()

   def _sample_n(self, n, seed=None):
-    scores = tf.convert_to_tensor(self.scores)
-    sample_shape = tf.concat([[n], tf.shape(scores)], axis=0)
+    logits = tf.convert_to_tensor(self.scores)
+    sample_shape = tf.concat([[n], tf.shape(logits)], axis=0)
     slots = []
     for _ in range(self._num_slots):
-      items = tfd.Categorical(logits=scores).sample()
+      items = tfd.Categorical(logits=logits).sample()
       slots.append(items)
-      scores -= tf.one_hot(items, sample_shape[-1], on_value=np.inf)
-      scores = self._penalizer_fn(scores, self._features, slots)
+      logits -= tf.one_hot(items, sample_shape[-1], on_value=np.inf)
+      logits = self._penalizer_fn(logits, self._features, slots)
     sample = tf.expand_dims(tf.stack(slots, axis=-1), axis=0)
     return sample
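
The `_sample_n` loop above samples without replacement: after each draw, the chosen item's logit is pushed to negative infinity so it has zero probability in all later draws, and the penalizer then rescores the remaining items. A minimal standalone sketch of the masking step (the toy names and values are ours, not part of the commit):

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

toy_logits = tf.constant([2.0, 1.0, 0.5, 0.1])
picked = tfd.Categorical(logits=toy_logits).sample()  # index of one sampled item
# Subtracting an inf-valued one-hot drives the picked logit to -inf, so the
# item can never be drawn again in subsequent rounds.
toy_logits -= tf.one_hot(picked, depth=4, on_value=np.inf)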

@@ -89,8 +90,8 @@ def _event_shape(self, scores=None):

 class CosinePenalizedPlackettLuce(PenalizedPlackettLuce):
   """A distribution that samples items based on scores and cosine similarity."""

-  def _penalizer_fn(self, scores, features, slots):
-    num_items = scores.shape[-1]
+  def _penalizer_fn(self, logits, features, slots):
+    num_items = logits.shape[-1]
     num_slotted = len(slots)
     slot_tensor = tf.stack(slots, axis=-1)
     # The tfd.Categorical distribution will give the sample `num_items` if all
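
The part of `_penalizer_fn` elided between this hunk and the next presumably builds `all_sims`, the cosine similarities between every item and the already-slotted items; the visible tail below reduces those to one value per item and mixes it into the logits. As a rough illustration of a diversity penalty in the same spirit (our own simplification, not the commit's exact computation):

import tensorflow as tf

def cosine_penalty(logits, features, slotted, coefficient=1.0):
  # Hypothetical helper, ours: shift each item's logit by its similarity
  # to the already-selected items.
  normed = tf.math.l2_normalize(features, axis=-1)            # [num_items, d]
  slotted_feats = tf.gather(normed, slotted)                  # [num_slotted, d]
  sims = tf.matmul(normed, slotted_feats, transpose_b=True)   # [num_items, num_slotted]
  # Penalize the largest similarity to any slotted item.
  return logits - coefficient * tf.reduce_max(sims, axis=-1)
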
@@ -109,16 +110,34 @@ def _penalizer_fn(self, scores, features, slots):

     sim_matrix = tf.reshape(all_sims, shape=[-1, num_items, num_slotted])
     similarity_boosts = tf.reduce_min(sim_matrix, axis=-1)
-    adjusted_scores = scores + (
+    adjusted_logits = logits + (
         self._penalty_mixture_coefficient * similarity_boosts)
-    return adjusted_scores
+    return adjusted_logits


-class NoPenaltyPlackettLuce(PenalizedPlackettLuce):
+class NoPenaltyPlackettLuce(tfd.PlackettLuce):
   """Identical to PlackettLuce, with input signature modified to our needs."""

-  def _penalizer_fn(self, scores, features, slots):
-    return scores
+  def __init__(self,
+               features: types.Tensor,
+               num_slots: int,
+               logits: types.Tensor,
+               penalty_mixture_coefficient: float = 1.):
+    """Initializes an instance of NoPenaltyPlackettLuce.
+
+    Args:
+      features: Unused for this distribution.
+      num_slots: The number of slots to fill: this many items will be sampled.
+      logits: Unnormalized log probabilities for the PlackettLuce distribution.
+        Shape is `[num_items]`.
+      penalty_mixture_coefficient: Unused for this distribution.
+    """
+    self._num_slots = num_slots
+    super(NoPenaltyPlackettLuce, self).__init__(scores=tf.math.exp(logits))
+
+  def sample(self, sample_shape=(), seed=None, name='sample', **kwargs):
+    return super(NoPenaltyPlackettLuce, self).sample(
+        sample_shape, seed, name, **kwargs)[:, :self._num_slots]


 class RankingPolicy(tf_policy.TFPolicy):
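
`tfd.PlackettLuce` is parameterized by nonnegative scores rather than logits, so `NoPenaltyPlackettLuce` passes `tf.math.exp(logits)`: normalizing exponentiated logits is exactly a softmax, which makes the two parameterizations agree. The slice `[:, :self._num_slots]` then keeps only the first `num_slots` positions of each sampled permutation. A quick numeric check (toy values, ours):

import tensorflow as tf

logits = tf.constant([[1.0, -0.5, 2.0]])
scores = tf.math.exp(logits)
tf.debugging.assert_near(
    scores / tf.reduce_sum(scores, axis=-1, keepdims=True),
    tf.nn.softmax(logits))
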
@@ -273,8 +292,7 @@ def __init__(self, unused_features: types.Tensor, num_slots: int,
     self._num_slots = num_slots

   def sample(self, shape=(), seed=None):
-    sorted_arms = tf.argsort(self._scores, direction='DESCENDING')
-    return sorted_arms[:, :self._num_slots]
+    return tf.math.top_k(self._scores, k=self._num_slots).indices


 class DescendingScoreRankingPolicy(RankingPolicy):
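
`tf.math.top_k` returns the indices of the `k` largest entries already in descending order, so it reproduces the old argsort-then-slice result (up to tie ordering) while only partially sorting the items. A small equivalence check (toy values, ours):

import tensorflow as tf

scores = tf.constant([[0.3, 0.9, 0.1, 0.7]])
k = 2
old_way = tf.argsort(scores, direction='DESCENDING')[:, :k]
new_way = tf.math.top_k(scores, k=k).indices
tf.debugging.assert_equal(old_way, new_way)  # both: [[1, 3]]
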
38 changes: 33 additions & 5 deletions tf_agents/bandits/policies/ranking_policy_test.py
@@ -27,17 +27,45 @@

 class RankingPolicyTest(test_utils.TestCase, parameterized.TestCase):

-  @parameterized.parameters(dict(batch_size=1, num_items=20, num_slots=5),
-                            dict(batch_size=3, num_items=15, num_slots=15),
-                            dict(batch_size=30, num_items=115, num_slots=100))
-  def testPolicy(self, batch_size, num_items, num_slots):
+  @parameterized.parameters(
+      dict(
+          policy_class=ranking_policy.DescendingScoreRankingPolicy,
+          batch_size=1,
+          num_items=20,
+          num_slots=5),
+      dict(
+          policy_class=ranking_policy.DescendingScoreRankingPolicy,
+          batch_size=3,
+          num_items=15,
+          num_slots=15),
+      dict(
+          policy_class=ranking_policy.PenalizeCosineDistanceRankingPolicy,
+          batch_size=30,
+          num_items=115,
+          num_slots=100),
+      dict(
+          policy_class=ranking_policy.PenalizeCosineDistanceRankingPolicy,
+          batch_size=1,
+          num_items=20,
+          num_slots=5),
+      dict(
+          policy_class=ranking_policy.NoPenaltyRankingPolicy,
+          batch_size=3,
+          num_items=15,
+          num_slots=15),
+      dict(
+          policy_class=ranking_policy.NoPenaltyRankingPolicy,
+          batch_size=30,
+          num_items=115,
+          num_slots=100))
+  def testPolicy(self, policy_class, batch_size, num_items, num_slots):
     obs_spec = bandit_spec_utils.create_per_arm_observation_spec(
         7, 5, num_items)
     time_step_spec = ts.time_step_spec(obs_spec)
     network = arm_net.create_feed_forward_common_tower_network(
         obs_spec, [3], [4], [5])

-    policy = ranking_policy.PenalizeCosineDistanceRankingPolicy(
+    policy = policy_class(
         num_items=num_items,
         num_slots=num_slots,
         time_step_spec=time_step_spec,
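
Each dict passed to absl's `@parameterized.parameters` becomes its own test case, with the dict keys bound to the method's keyword arguments, so the single `testPolicy` body now covers all three policy classes. A stripped-down illustration (the toy test is ours):

from absl.testing import absltest
from absl.testing import parameterized


class ToyTest(parameterized.TestCase):

  @parameterized.parameters(
      dict(x=1, expected=2),
      dict(x=5, expected=10))
  def testDouble(self, x, expected):
    self.assertEqual(2 * x, expected)


if __name__ == '__main__':
  absltest.main()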
