[go: nahoru, domu]

Skip to content

Commit

Permalink
A ranking environment that explicitly models observation probabilities.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 493826467
Change-Id: Id3c2c7007a69d2b47eeb2ccc60fae1ca66b24030
  • Loading branch information
bartokg authored and Copybara-Service committed Dec 8, 2022
1 parent 6d69c46 commit 31b7b7a
Show file tree
Hide file tree
Showing 2 changed files with 181 additions and 1 deletion.
121 changes: 120 additions & 1 deletion tf_agents/bandits/environments/ranking_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
threshold, no item is selected by the user.
"""
from typing import Optional, Callable, Text
from typing import Optional, Callable, Sequence, Text

import numpy as np
import tensorflow as tf
Expand Down Expand Up @@ -279,3 +279,122 @@ def _choose_items_distance_based(self, global_obs, slotted_items):
[scores,
np.ones([self._batch_size, 1]) * self._distance_threshold], axis=1)
return np.argmax(scores, axis=1)


class ExplicitPositionalBiasRankingEnvironment(
    bandit_py_environment.BanditPyEnvironment):
  """A ranking environment in which one can explicitly set positional bias.

  This environment assumes that the user's click is governed by two independent
  values: the relevance and the observation probability. Relevance is based on
  a random function whose input is the global and item features. The observation
  probability is a parameter list that sets observation probabilities for all of
  the slot positions. The observation probabilities don't depend on the context
  or the items placed in the slots.

  The user model: an item is clicked if it is observed and relevant. Hence,
  multiple items can be clicked in one sample.
  """
  _observation: types.NestedArray

  def __init__(self,
               global_sampling_fn: Callable[[], types.Array],
               item_sampling_fn: Callable[[], types.Array],
               relevance_fn: Callable[[types.Array, types.Array], float],
               num_items: int,
               observation_probs: Sequence[float],
               batch_size: int = 1,
               name: Optional[Text] = None):
    """Initializes an instance of `ExplicitPositionalBiasRankingEnvironment`.

    Args:
      global_sampling_fn: A function that outputs a random 1d array or
        list of ints or floats. This output is the global context. Its shape and
        type must be consistent across calls.
      item_sampling_fn: A function that outputs a random 1d array or list
        of ints or floats (same type as the output of `global_sampling_fn`).
        This output is the per-arm context. Its shape must be consistent across
        calls.
      relevance_fn: A function that, called with global features and features of
        one item, returns click probabilities, given the item was observed.
      num_items: (int) the number of items in every sample.
      observation_probs: Observation probabilities for all slots. The length of
        this list determines the number of slots.
      batch_size: The batch size.
      name: The name of this environment instance.

    Raises:
      ValueError: If any entry of `observation_probs` is not within `[0, 1]`
        (this includes NaN entries).
    """
    self._global_sampling_fn = global_sampling_fn
    self._item_sampling_fn = item_sampling_fn
    self._num_items = num_items

    self._num_slots = len(observation_probs)
    self._observation_probs = np.array(observation_probs)
    # Phrased as a positive containment test: comparisons against NaN are
    # always False, so checking `> 1` / `< 0` would let NaN slip through.
    in_unit_interval = ((self._observation_probs >= 0) &
                        (self._observation_probs <= 1))
    if not np.all(in_unit_interval):
      raise ValueError('Observation probabilities need to be in [0, 1].')
    self._relevance_fn = relevance_fn
    self._batch_size = batch_size

    # The sampling functions are called once each here purely to derive the
    # observation specs from an example output.
    global_spec = array_spec.ArraySpec.from_array(global_sampling_fn())
    item_spec = array_spec.add_outer_dims_nest(
        array_spec.ArraySpec.from_array(item_sampling_fn()),
        (num_items,))
    observation_spec = {GLOBAL_KEY: global_spec, PER_ARM_KEY: item_spec}
    self._global_dim = global_spec.shape[0]

    # An action assigns one item index (in `[0, num_items - 1]`) to each slot.
    action_spec = array_spec.BoundedArraySpec(
        shape=(self._num_slots,),
        dtype=np.int32,
        minimum=0,
        maximum=num_items - 1,
        name='action')
    # The reward is a per-slot vector.
    reward_spec = array_spec.ArraySpec(
        shape=[self._num_slots], dtype=np.float32, name='score_vector')

    super(ExplicitPositionalBiasRankingEnvironment, self).__init__(
        observation_spec, action_spec, reward_spec, name=name)

  def _observe(self) -> types.NestedArray:
    """Samples a fresh batch of global and per-item features.

    Returns:
      A dict with the stacked global features under `GLOBAL_KEY` and the item
      features, reshaped to `[batch_size, num_items, -1]`, under `PER_ARM_KEY`.
      The same dict is stored on the instance for use by `_apply_action`.
    """
    global_obs = np.stack(
        [self._global_sampling_fn() for _ in range(self._batch_size)])
    item_obs = np.reshape([
        self._item_sampling_fn()
        for _ in range(self._batch_size * self._num_items)
    ], (self._batch_size, self._num_items, -1))
    self._observation = {GLOBAL_KEY: global_obs, PER_ARM_KEY: item_obs}
    return self._observation

  def _apply_action(self, action: np.ndarray) -> types.Array:
    """Samples per-slot clicks for a batch of rankings.

    Args:
      action: Integer array of shape `[batch_size, num_slots]`; entry `[i, j]`
        is the index of the item placed in slot `j` for batch member `i`.

    Returns:
      A `float32` array of shape `[batch_size, num_slots]` holding 0/1 click
      samples.

    Raises:
      ValueError: If the shape of `action` is not `[batch_size, num_slots]`.
    """
    if action.shape[0] != self.batch_size:
      raise ValueError('Number of actions must match batch size.')
    # Without this check, surplus slots would be silently ignored and missing
    # ones would surface as an opaque indexing error further down.
    if action.ndim != 2 or action.shape[1] != self._num_slots:
      raise ValueError(
          'Actions must be of shape [batch_size, num_slots] = [{}, {}], but '
          'received shape {}.'.format(
              self.batch_size, self._num_slots, list(action.shape)))
    global_obs = self._observation[GLOBAL_KEY]
    item_obs = self._observation[PER_ARM_KEY]
    batch_size_range = range(self.batch_size)
    # Pairs every row of `action` with its own batch index, gathering the
    # features of the chosen items.
    slotted_items = item_obs[np.expand_dims(batch_size_range, axis=1), action]
    relevances = self._get_relevances(global_obs, slotted_items)

    # The `relevances` array is of shape `[batch_size, num_slots]`, the
    # `observation_probs` array is of shape `[num_slots]`. With broadcasting,
    # `click_probabilities` becomes an array of shape `[batch_size, num_slots]`.
    click_probabilities = relevances * self._observation_probs
    scores = np.random.binomial(1, click_probabilities).astype(np.float32)
    return scores

  def _get_relevances(self, global_obs, slotted_items):
    """Returns the relevance of each item in a batched action.

    Args:
      global_obs: Global features, indexed by batch member.
      slotted_items: Features of the slotted items, indexed by batch member and
        then by slot.

    Returns:
      An array of shape `[batch_size, num_slots]` with the outputs of
      `relevance_fn`, clipped to `[0, 1]`. A warning is printed if clipping
      changed any value.
    """
    s_range = range(self._num_slots)
    b_range = range(self._batch_size)

    relevances = np.array([[
        self._relevance_fn(global_obs[i], slotted_items[i, j]) for j in s_range
    ] for i in b_range])
    clipped_relevances = np.clip(relevances, 0., 1.)
    if not np.all(relevances == clipped_relevances):
      print('Warning: relevance probabilities outside of `[0, 1]`.')
    return clipped_relevances

  def batched(self) -> bool:
    """Returns True: this environment always operates on batches."""
    return True

  @property
  def batch_size(self) -> int:
    """The batch size given at construction."""
    return self._batch_size
61 changes: 61 additions & 0 deletions tf_agents/bandits/environments/ranking_environment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,5 +153,66 @@ def _item_sampling_fn():
[0, 0, 3, 0, 0, 0, 0, 0, 0]])


class ExplicitBiasEnvironmentTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for `ExplicitPositionalBiasRankingEnvironment`."""

  @parameterized.parameters([{
      'batch_size': 1,
      'global_dim': 4,
      'item_dim': 5,
      'num_items': 7,
      'num_slots': 5,
  }, {
      'batch_size': 8,
      'global_dim': 12,
      'item_dim': 4,
      'num_items': 23,
      'num_slots': 9,
  }])
  def test_explicit_bias_environment(self, batch_size, global_dim, item_dim,
                                     num_items, num_slots):
    """Drives the environment with a random policy and checks its outputs."""

    def _global_sampling_fn():
      return np.random.randint(-10, 10, [global_dim])

    def _item_sampling_fn():
      return np.random.randint(-2, 3, [item_dim])

    def _relevance_fn(global_obs, item_obs):
      # Sigmoid of the dot product over the feature dimensions that the global
      # and item vectors have in common.
      min_dim = min(global_dim, item_dim)
      dot_prod = np.dot(global_obs[:min_dim],
                        item_obs[:min_dim]).astype(np.float32)
      return 1 / (1 + np.exp(-dot_prod))

    # Observation probabilities that decrease with the slot index, starting
    # at 0.75 and staying above 0.25.
    positional_biases = list(0.75 - np.arange(num_slots) / (2 * num_slots))
    env = ranking_environment.ExplicitPositionalBiasRankingEnvironment(
        _global_sampling_fn,
        _item_sampling_fn,
        _relevance_fn,
        num_items,
        positional_biases,
        batch_size,
    )

    time_step_spec = env.time_step_spec()
    action_spec = env.action_spec()

    random_policy = random_py_policy.RandomPyPolicy(
        time_step_spec=time_step_spec, action_spec=action_spec)

    for _ in range(5):
      time_step = env.reset()
      self.assertTrue(
          check_unbatched_time_step_spec(
              time_step=time_step,
              time_step_spec=time_step_spec,
              batch_size=env.batch_size))

      action = random_policy.action(time_step).action
      self.assertAllEqual(action.shape, [batch_size, num_slots])
      # Every slot must reference a valid item index.
      self.assertAllGreaterEqual(action, 0)
      self.assertAllLess(action, num_items)
      time_step = env.step(action)
      reward = time_step.reward
      self.assertAllEqual(reward.shape, [batch_size, num_slots])
      # Rewards are per-slot click indicators, so only 0 and 1 may occur.
      self.assertAllInSet(reward, (0., 1.))


# Hand control to the TensorFlow test runner when this file is executed
# directly as a script.
if __name__ == '__main__':
  tf.test.main()

0 comments on commit 31b7b7a

Please sign in to comment.