[go: nahoru, domu]

Skip to content

Commit

Permalink
Score vector feedback model in the ranking environment.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 466327512
Change-Id: Ie25bf517d82834bd3cf50258862b6afc030fb30f
  • Loading branch information
bartokg authored and Copybara-Service committed Aug 9, 2022
1 parent b168f39 commit 4e3d0d0
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 8 deletions.
21 changes: 17 additions & 4 deletions tf_agents/bandits/environments/ranking_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ class FeedbackModel(object):
UNKNOWN = 0
# Cascading feedback model: A tuple of the chosen index and its value.
CASCADING = 1
# Score Vector feedback model: Every element in the output ranking receives a
# score value.
SCORE_VECTOR = 2


class ClickModel(object):
Expand Down Expand Up @@ -168,6 +171,9 @@ def __init__(self,
array_spec.ArraySpec(
shape=[], dtype=np.float32, name='chosen_value')
}
elif feedback_model == FeedbackModel.SCORE_VECTOR:
reward_spec = array_spec.ArraySpec(
shape=[num_slots], dtype=np.float32, name='score_vector')
else:
raise NotImplementedError(
'Feedback model {} not implemented'.format(feedback_model))
Expand Down Expand Up @@ -209,11 +215,18 @@ def _apply_action(self, action: np.ndarray) -> types.Array:
self._click_model))

if self._feedback_model == FeedbackModel.CASCADING:
chosen_items = np.array(
chosen_items, dtype=self._reward_spec['chosen_index'].dtype)
chosen_values = (chosen_items < self._num_slots).astype(
self._reward_spec['chosen_value'].dtype)
chosen_items = np.array(chosen_items, dtype=np.float32)
chosen_values = (chosen_items < self._num_slots).astype(np.float32)
return {'chosen_index': chosen_items, 'chosen_value': chosen_values}
elif self._feedback_model == FeedbackModel.SCORE_VECTOR:
chosen_values = (chosen_items < self._num_slots).astype(np.float32)
return self._cascading_to_scorevector(chosen_items, chosen_values)

def _cascading_to_scorevector(self, chosen_items, chosen_values):
  """Converts cascading (index, value) feedback into per-slot score vectors.

  Args:
    chosen_items: array of per-sample chosen slot indices. An index equal to
      `self._num_slots` denotes "no click" for that sample.
    chosen_values: array of the corresponding per-sample feedback values.

  Returns:
    A `[batch_size, num_slots]` float32 array in which each sample's value is
    scattered into its chosen slot; rows for no-click samples are all zeros.
  """
  # Allocate one extra column to act as a sentinel slot for no-click samples.
  padded_scores = np.zeros(
      (self.batch_size, self._num_slots + 1), dtype=np.float32)
  batch_rows = np.arange(self.batch_size)
  padded_scores[batch_rows, chosen_items] = chosen_values
  # Drop the sentinel column: no-click feedback must not appear in the output.
  return padded_scores[:, :-1]

def _step(self, action):
"""We need to override this function because the reward dtype can be int."""
Expand Down
52 changes: 48 additions & 4 deletions tf_agents/bandits/environments/ranking_environment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,20 @@ class RankingPyEnvironmentTest(tf.test.TestCase, parameterized.TestCase):
'item_dim': 5,
'num_items': 7,
'num_slots': 5,
'feedback_model': ranking_environment.FeedbackModel.CASCADING,
'click_model': ranking_environment.ClickModel.GHOST_ACTIONS
}, {
'batch_size': 8,
'global_dim': 12,
'item_dim': 4,
'num_items': 23,
'num_slots': 9,
'feedback_model': ranking_environment.FeedbackModel.SCORE_VECTOR,
'click_model': ranking_environment.ClickModel.DISTANCE_BASED
}])
def test_ranking_environment(self, batch_size, global_dim, item_dim,
num_items, num_slots, click_model):
num_items, num_slots, feedback_model,
click_model):

def _global_sampling_fn():
return np.random.randint(-10, 10, [global_dim])
Expand All @@ -82,6 +85,7 @@ def _item_sampling_fn():
num_items=num_items,
num_slots=num_slots,
scores_weight_matrix=scores_weight_matrix,
feedback_model=feedback_model,
click_model=click_model,
distance_threshold=10.0,
batch_size=batch_size)
Expand All @@ -104,9 +108,49 @@ def _item_sampling_fn():
self.assertAllGreaterEqual(action, 0)
time_step = env.step(action)
reward = time_step.reward
self.assertAllEqual(reward['chosen_index'].shape, [batch_size])
self.assertAllGreaterEqual(reward['chosen_index'], 0)
self.assertAllEqual(reward['chosen_value'].shape, [batch_size])
if feedback_model == ranking_environment.FeedbackModel.CASCADING:
self.assertAllEqual(reward['chosen_index'].shape, [batch_size])
self.assertAllGreaterEqual(reward['chosen_index'], 0)
self.assertAllEqual(reward['chosen_value'].shape, [batch_size])
else:
self.assertAllEqual(reward.shape, [batch_size, num_slots])

def test_cascading_to_scorevector(self):
  """Verifies cascading (index, value) feedback is scattered into slot scores."""
  batch_size = 5
  global_dim = 12
  item_dim = 4
  num_items = 23
  num_slots = 9

  def _global_sampling_fn():
    return np.random.randint(-10, 10, [global_dim])

  def _item_sampling_fn():
    return np.random.randint(-2, 3, [item_dim])

  # Use the builtin `float` (i.e. float64): the `np.float` alias was
  # deprecated in NumPy 1.20 and removed in NumPy 1.24.
  scores_weight_matrix = (np.reshape(
      np.arange(global_dim * item_dim, dtype=float),
      newshape=[item_dim, global_dim]) - 10) / 5
  env = ranking_environment.RankingPyEnvironment(
      _global_sampling_fn,
      _item_sampling_fn,
      num_items=num_items,
      num_slots=num_slots,
      scores_weight_matrix=scores_weight_matrix,
      feedback_model=ranking_environment.FeedbackModel.SCORE_VECTOR,
      click_model=ranking_environment.ClickModel.DISTANCE_BASED,
      distance_threshold=10.0,
      batch_size=batch_size)

  chosen_items = np.array([0, 2, 9, 1, 2])
  chosen_values = np.array([6, 8, 4, 2, 3])
  score_vector = env._cascading_to_scorevector(chosen_items, chosen_values)
  self.assertAllEqual(score_vector.shape, [batch_size, num_slots])

  # The third row is all zeros because `chosen_item == 9` means no click.
  self.assertAllEqual(score_vector, [[6, 0, 0, 0, 0, 0, 0, 0, 0],
                                     [0, 0, 8, 0, 0, 0, 0, 0, 0],
                                     [0, 0, 0, 0, 0, 0, 0, 0, 0],
                                     [0, 2, 0, 0, 0, 0, 0, 0, 0],
                                     [0, 0, 3, 0, 0, 0, 0, 0, 0]])


if __name__ == '__main__':
Expand Down

0 comments on commit 4e3d0d0

Please sign in to comment.