Commit ec3778d
Ranking train-eval changes for positional bias environment.
PiperOrigin-RevId: 494696845
Change-Id: I1728b5d96fe45e727aa38fc5b9a5055cefa54e05
bartokg authored and Copybara-Service committed Dec 12, 2022
1 parent 2db2c7c commit ec3778d
Showing 2 changed files with 63 additions and 35 deletions.
96 changes: 62 additions & 34 deletions tf_agents/bandits/agents/examples/v2/train_eval_ranking.py
@@ -40,61 +40,87 @@
                     '`distance_based`.')
 flags.DEFINE_float('distance_threshold', 10.0, 'If the diversity model is '
                    '`distance_based`, this is the distance threshold.')
+
+flags.DEFINE_string('env_type', 'base', 'The environment used. Possible values'
+                    ' are `base` and `exp_pos_bias`.')
 
 FLAGS = flags.FLAGS

 # Environment and driver parameters.
 
 BATCH_SIZE = 128
-NUM_ITEMS = 101
-NUM_SLOTS = 11
+NUM_ITEMS = 1001
+NUM_SLOTS = 5
+GLOBAL_DIM = 50
+ITEM_DIM = 40
 
 TRAINING_LOOPS = 2000
 STEPS_PER_LOOP = 2
 
-LR = 0.005
+LR = 0.05


 def main(unused_argv):
 
   def _global_sampling_fn():
-    return np.random.randint(-1, 1, [5]).astype(np.float32)
+    return np.random.randint(-1, 1, [GLOBAL_DIM]).astype(np.float32)

   def _item_sampling_fn():
-    return np.random.randint(-2, 3, [4]).astype(np.float32)
-
-  # Inner product with the last global dimension ignored.
-  scores_weight_matrix = np.eye(4, 5, dtype=np.float32)
+    unnormalized = np.random.randint(-2, 3, [ITEM_DIM]).astype(np.float32)
+    return unnormalized / np.linalg.norm(unnormalized)
+
+  def _relevance_fn(global_obs, item_obs):
+    min_dim = min(GLOBAL_DIM, ITEM_DIM)
+    dot_prod = np.dot(global_obs[:min_dim],
+                      item_obs[:min_dim]).astype(np.float32)
+    return 1 / (1 + np.exp(-dot_prod))

+  if FLAGS.env_type == 'exp_pos_bias':
+    positional_biases = list(1.0 / np.arange(1, NUM_SLOTS + 1)**1.3)
+    env = ranking_environment.ExplicitPositionalBiasRankingEnvironment(
+        _global_sampling_fn,
+        _item_sampling_fn,
+        _relevance_fn,
+        NUM_ITEMS,
+        positional_biases,
+        batch_size=BATCH_SIZE)
+    feedback_model = ranking_agent.FeedbackModel.SCORE_VECTOR
+  elif FLAGS.env_type == 'base':
+    # Inner product with the excess dimensions ignored.
+    scores_weight_matrix = np.eye(ITEM_DIM, GLOBAL_DIM, dtype=np.float32)

+    feedback_model = ranking_agent.FeedbackModel.SCORE_VECTOR
+    if FLAGS.feedback_model == 'cascading':
+      feedback_model = ranking_agent.FeedbackModel.CASCADING
+    else:
+      raise NotImplementedError('Feedback model {} not implemented'.format(
+          FLAGS.feedback_model))
+    if FLAGS.click_model == 'ghost_actions':
+      click_model = ranking_environment.ClickModel.GHOST_ACTIONS
+    elif FLAGS.click_model == 'distance_based':
+      click_model = ranking_environment.ClickModel.DISTANCE_BASED
+    else:
+      raise NotImplementedError('Diversity mode {} not implemented'.format(
+          FLAGS.click_model))
+
+    env = ranking_environment.RankingPyEnvironment(
+        _global_sampling_fn,
+        _item_sampling_fn,
+        num_items=NUM_ITEMS,
+        num_slots=NUM_SLOTS,
+        scores_weight_matrix=scores_weight_matrix,
+        # TODO(b/247995883): Merge the two feedback model enums from the agent
+        # and the environment.
+        feedback_model=feedback_model.value,
+        click_model=click_model,
+        distance_threshold=FLAGS.distance_threshold,
+        batch_size=BATCH_SIZE)

-  if FLAGS.feedback_model == 'cascading':
-    feedback_model = ranking_environment.FeedbackModel.CASCADING
-  else:
-    raise NotImplementedError('Feedback model {} not implemented'.format(
-        FLAGS.feedback_model))
-  if FLAGS.click_model == 'ghost_actions':
-    click_model = ranking_environment.ClickModel.GHOST_ACTIONS
-  elif FLAGS.click_model == 'distance_based':
-    click_model = ranking_environment.ClickModel.DISTANCE_BASED
-  else:
-    raise NotImplementedError('Diversity mode {} not implemented'.format(
-        FLAGS.click_mode))
-
-  env = ranking_environment.RankingPyEnvironment(
-      _global_sampling_fn,
-      _item_sampling_fn,
-      num_items=NUM_ITEMS,
-      num_slots=NUM_SLOTS,
-      scores_weight_matrix=scores_weight_matrix,
-      feedback_model=feedback_model,
-      click_model=click_model,
-      distance_threshold=FLAGS.distance_threshold,
-      batch_size=BATCH_SIZE)
   environment = tf_py_environment.TFPyEnvironment(env)
 
   obs_spec = environment.observation_spec()
   network = (
       global_and_arm_feature_network.create_feed_forward_common_tower_network(
-          obs_spec, (20, 10), (20, 10), (20, 10)))
+          obs_spec, (40, 10), (40, 10), (20, 10)))
   if FLAGS.policy_type == 'cosine_distance':
     policy_type = ranking_agent.RankingPolicyType.COSINE_DISTANCE
   elif FLAGS.policy_type == 'no_penalty':
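For context on the new `exp_pos_bias` path above: the slot biases follow an inverse power law, so the probability that a user even observes a slot decays quickly down the ranking. A minimal NumPy sketch of that arithmetic (the constant mirrors the diff; the rounding and print are purely illustrative):

import numpy as np

NUM_SLOTS = 5

# Observation probability for slot k (1-indexed) is 1 / k**1.3, matching the
# positional_biases passed to ExplicitPositionalBiasRankingEnvironment above.
positional_biases = list(1.0 / np.arange(1, NUM_SLOTS + 1)**1.3)
print([round(b, 3) for b in positional_biases])
# -> [1.0, 0.406, 0.24, 0.165, 0.123]

With the top slot observed roughly eight times as often as the fifth, per-slot score-vector feedback is what gives the agent a learning signal at every position despite the bias.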
@@ -108,6 +134,7 @@ def _item_sampling_fn():
       scoring_network=network,
       optimizer=tf.compat.v1.train.AdamOptimizer(learning_rate=LR),
       policy_type=policy_type,
+      feedback_model=feedback_model,
       summarize_grads_and_vars=True)
 
   def order_items_from_action_fn(orig_trajectory):
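The body of `order_items_from_action_fn` is collapsed in this view; the final hunk below only shows it being passed to the trainer as `training_data_spec_transformation_fn`. As a rough, hypothetical sketch of the kind of reordering such a transformation performs (the names and shapes here are assumptions, not the file's actual code):

import tensorflow as tf

def reorder_items_by_action(item_features, action):
  """Gathers item features in the order the policy ranked them.

  Args:
    item_features: [batch_size, num_items, item_dim] float tensor.
    action: [batch_size, num_slots] int32 tensor of chosen item indices.

  Returns:
    [batch_size, num_slots, item_dim] features, one row per slot.
  """
  return tf.gather(item_features, action, batch_dims=1)

# Batch of 1, three items, two slots; the policy put item 2 first, item 0 second.
feats = tf.constant([[[1., 1.], [2., 2.], [3., 3.]]])
action = tf.constant([[2, 0]])
reordered = reorder_items_by_action(feats, action)  # [[[3., 3.], [1., 1.]]]

Training then sees item features in slot order, which is the layout the cascading and score-vector feedback models assume.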
@@ -164,7 +191,8 @@ def order_items_from_action_fn(orig_trajectory):
       environment=environment,
       training_loops=TRAINING_LOOPS,
       steps_per_loop=STEPS_PER_LOOP,
-      training_data_spec_transformation_fn=order_items_from_action_fn)
+      training_data_spec_transformation_fn=order_items_from_action_fn,
+      save_policy=False)
 
 
 if __name__ == '__main__':
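Scale check for the new constants: each training loop collects STEPS_PER_LOOP × BATCH_SIZE = 2 × 128 = 256 ranking rounds, so TRAINING_LOOPS = 2000 amounts to roughly 512,000 rounds overall, while the new `save_policy=False` asks the trainer not to export the policy during the run.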
2 changes: 1 addition & 1 deletion tf_agents/bandits/environments/ranking_environment.py
@@ -299,7 +299,7 @@ class ExplicitPositionalBiasRankingEnvironment(

   def __init__(self, global_sampling_fn: Callable[[], types.Array],
                item_sampling_fn: Callable[[], types.Array],
-               relevance_fn: Callable[[], float],
+               relevance_fn: Callable[[types.Array, types.Array], float],
                num_items: int,
                observation_probs: Sequence[float],
                batch_size: int = 1,
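This one-line fix tightens the `relevance_fn` type hint to match how the environment actually calls it: with a global observation and an item observation, returning a float, exactly as `_relevance_fn` in the train-eval script above. A minimal conforming callable (a sketch; only the signature is prescribed by this diff):

import numpy as np
from tf_agents.typing import types


def my_relevance_fn(global_obs: types.Array, item_obs: types.Array) -> float:
  # Sigmoid of the dot product over the dimensions the two observations share,
  # mirroring _relevance_fn in train_eval_ranking.py.
  min_dim = min(len(global_obs), len(item_obs))
  dot = float(np.dot(global_obs[:min_dim], item_obs[:min_dim]))
  return float(1.0 / (1.0 + np.exp(-dot)))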
