Beyond the Basics: Reinforcement Learning with Jax — Part II: Developing an Exploitative Alternative to A/B Testing | by Lando L | Jun, 2023

Think back to the last time you scrolled through your video streaming service looking for a new movie to watch. If you did not already know what to look for, you may have been affected by the paradox of choice¹. Having such an array of potentially good movies can make it difficult to make informed decisions. Instead, we often rely on simpler decisions that can be made almost instantly. As a result, we tend to compare videos based on their thumbnails and titles. Knowing of this effect, video production and streaming companies strive to optimise thumbnails and titles in order to increase their movies’ click rates. Without prior knowledge of the audience’s preferences, the process of creating engaging thumbnails and titles becomes one of trial-and-error. (I hope this rings a bell by now!)

The traditional approach

Let us imagine ourselves in the shoes of the decision-maker tasked with choosing the right thumbnail and title for a newly published movie on the streaming website. In addition to the original movie thumbnail and title, we are given a set of proposals consisting of two thumbnails and two titles. The traditional data-driven approach in this scenario would be to run an A/B test and compare the click rate of the original version with those of the proposed versions. The advantage of an A/B test is that, given enough samples, it can identify whether or not the measured difference in click rate is statistically significant. However, it limits us to comparing only two variants at a time and often requires a large number of samples to reach a result. Furthermore, if one of the variants performs drastically better than the other, we are still forced to serve the worse-performing variant to half of our customers until the end of the experiment, potentially losing revenue.

The Reinforcement Learning approach

Alternatively, we can set up a multi-armed bandit (MAB) experiment. MABs are a simplification of Reinforcement Learning (RL) problems, as they can be defined entirely by a set of actions and a reward function. This makes them effectively finite Markov Decision Processes (MDPs) with only one state. Unlike in MDPs, the actions in MABs are independent of each other: at every time step, the same set of actions provides the same reward distributions. Therefore, if we know the exact reward distribution of the available actions, we can simply use a greedy approach and choose the action with the highest payout. In contrast, MDPs sometimes require us to take ‘suboptimal’ actions to reach a highly rewarding state. However, fewer problems can be modelled with MABs than with MDPs. For example, in the music practice environment from the last post, it was necessary to introduce states to model our friend’s mood, something which cannot be done with MABs.

We can model the challenge from our video streaming platform example by defining the actions and reward function as follows:

  • Every time a user visits the website, we must choose which version to display. Our options consist of showing the original movie version, or one of the four variations created by combining the two thumbnails and two titles.
  • Once we have selected the variant to display, the user has the option to watch or not watch our film. Consequently, our reward function is binary, yielding a reward of zero if the user chooses not to view the movie and a reward of one if they decide to watch the movie.

The advantages of the MAB approach compared with traditional A/B testing are numerous; it allows for an unlimited number of variants to be tested simultaneously, dynamically lowers the frequency of poorer performing variants, and converges with fewer samples — all of which lead to cost savings. The downside is that it does not provide the statistical significance of a traditional A/B test; however, when the implications of the experiment are not affecting people’s wellbeing (as is the case in most instances), statistical significance is not strictly necessary.

Finally, we will start writing code. For the code examples of this course, we will be using Python as our programming language, and will largely use Jax as our Machine Learning (ML) framework. Developed by Google, Jax is a Python library that is specifically designed for ML research. Unlike TensorFlow and PyTorch, it is built around the functional programming (FP) paradigm, making it highly composable and promoting the concept of pure functions that have no side effects. This means that all state changes, such as parameter updates or splitting of random number generators, must be done explicitly. Although this may require more lines of code than the object-oriented programming (OOP) equivalent, it gives developers full control over state changes, leading to a better understanding of the code and fewer surprises.
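To make the idea of explicit state concrete, here is a minimal sketch of how random number generation can look in Jax. The variable names are illustrative only; the point is that the generator state (the key) is passed in and split by hand rather than hidden inside a global generator.

```python
from jax import random

# The generator state is an explicit value, not a hidden global
key = random.PRNGKey(0)

# Splitting yields a new key to carry forward and a subkey to consume
key, subkey = random.split(key)
first = random.uniform(subkey)

# Reusing the same subkey reproduces the same number: no hidden state changed
again = random.uniform(subkey)

# To get a fresh number we have to split again, explicitly
key, subkey = random.split(key)
second = random.uniform(subkey)

print(first, again, second)
```

Because nothing changes behind the scenes, drawing twice from the same subkey gives the same value, which is what makes experiments easy to replicate.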

(The full code is available as a Jupyter notebook on GitHub and Kaggle.)

Implementing the environment

As the first step, we implement the environment for our video streaming platform challenge. At a high level, whenever someone visits the platform to scroll through the available movie options, the environment needs to ask the agent which thumbnails and titles to display, and then communicate to the agent which movie the visitor chose. The complexity of this task will depend on the maturity of the platform’s architecture, which could range from simply changing a few lines of code to developing an entirely new service.

For the purpose of this course, we will keep the environment simple by implementing it as a three-step process:

  1. Asking the agent which of the five variants it wishes to display.
  2. Simulating the visitor’s choice.
  3. Communicating to the agent its reward for the decision.

Since we are using Jax as our main ML framework, we need to import Jax and the three modules numpy, lax and random²:

import jax

# Numpy API with hardware acceleration and automatic differentiation
from jax import numpy as jnp

# Low level operators
from jax import lax

# API for working with pseudorandom number generators
from jax import random

Next, we set the constants of our environment, consisting of a random seed to ensure replicability, the number of visitors we want to simulate, and the expected click rates. It is important to note that, in the real world, the click rates are considered unknown. Since we are not running an actual experiment, we must simulate the visitors’ click behaviour. To do this, we define different click rates for the five variants to imitate the visitors’ preference, with the original variant having a click rate of 4.2%, and the four variants having click rates of 3%, 3.5%, 3.8%, and 4.5%, respectively.

# Random seed to make our experiment replicable 
SEED = 42

# Number of visitors we want to simulate
NUM_VISITS = 10000

# Expected click rates for the five variants, with the original variant first
CLICK_RATES = [0.042, 0.03, 0.035, 0.038, 0.045]

Finally, we define a generic function that simulates a user visit, or a single step in our environment. This function comprises three steps and is close to the high-level implementation we set out earlier:

  1. Executing the agent’s policy to determine which variant to display to the user.
  2. Randomly simulating whether the user clicked on the movie or not.
  3. Updating the agent’s parameters based on the variant shown and the associated reward.

def visit(state, timestep, click_rates, policy_fn, update_fn):
    """Simulates a user visit."""

    # Unpacking the environment state into
    # the agent's parameters and the random number generator
    params, rng = state

    # Splitting the random number generator
    next_rng, policy_rng, user_rng = random.split(rng, num=3)

    # Selecting the variant to show the user, based on
    # the given policy, the agent's parameters, and the current timestep
    variant = policy_fn(params, timestep, policy_rng)

    # Randomly simulating the user click, based on
    # the variant's click rate
    clicked = random.uniform(user_rng) < click_rates[variant]

    # Calculating the agent's updated parameters, based on
    # the current parameters, the selected variant,
    # and whether or not the user clicked
    next_params = update_fn(params, variant, clicked)

    # Returning the updated experiment state (params and rng) and
    # whether or not the user clicked
    return (next_params, next_rng), clicked

Before we continue, let us discuss the function signature. If we fix the parameters click_rates, policy_fn, and update_fn, the visit function takes the environment state and the current timestep as its parameters and returns a tuple containing the next environment state and a boolean value that encodes the binary reward. In Haskell notation, the function signature would look like this:

state -> timestep -> (state, Bool)

Hence, to simulate the n-th step in our environment we pass the function the n-th state and n-th timestep, and we receive the (n+1)-th state and the n-th reward. For the (n+1)-th step, we use the same function, by passing it the (n+1)-th state returned by the last function call and the time step n+1.

# Initialising the state
s0 = ...

# Simulating the time steps 0, 1, and 2
s1, r0 = visit(s0, 0)
s2, r1 = visit(s1, 1)
s3, r2 = visit(s2, 2)
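The same pattern can be tried out with a toy example. The sketch below is not the article's environment: step and its increment parameter are hypothetical stand-ins, chosen only to show how fixing the extra parameters with partial leaves a pure function of state and timestep whose output is fed back in as the next input.

```python
from functools import partial

def step(state, timestep, increment):
    """Hypothetical pure step: returns the next state and a boolean reward."""
    next_state = state + increment
    reward = next_state % 2 == 0
    return next_state, reward

# Fixing the extra parameter leaves a function of (state, timestep) only
fixed_step = partial(step, increment=3)

s0 = 0
s1, r0 = fixed_step(s0, 0)
s2, r1 = fixed_step(s1, 1)

# Calling with the same inputs always gives the same outputs
assert fixed_step(s0, 0) == (s1, r0)
print(s1, r0, s2, r1)
```

The timestep is unused in this toy step; it is kept only so that the signature mirrors the one discussed above.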

Having to pass a state and timestep parameter to every call of the visit function may seem cumbersome to people accustomed to OOP. However, in this example, using a pure function implementation for the environment has several advantages over an object-oriented approach. Firstly, it explicitly states which parameters the environment depends on, thus eliminating any hidden global variables that may influence the outcome. Secondly, it makes it easier to test the environment with various states and timesteps without having to set and read the environment’s internal state. Lastly, we will discover a useful function from the Jax library that handles the state management for us, thus drastically reducing the code needed at the call site.

Implementing the policies

With the environment in place, we can now implement our decision-making process or policies. We have already identified an optimal policy for MABs. Given the action-value distribution, the best action to take is the one with the highest expected payout, which is referred to as exploitation. However, since the actual action-value distribution is unknown to us, we must try out all available options at least once in order to estimate the distribution, a process often referred to as exploration. The delicate balance between exploration and exploitation is a recurrent theme in RL and will be discussed in more detail throughout the course.

The three policies we will cover in this blog post are the epsilon-greedy policy, the Boltzmann policy, and the upper-confidence-bound policy. All of these are action-value methods, meaning they explicitly estimate the values of all actions and base their decision-making on these estimates. At the end, we will cover a bonus policy, which is based on the Thompson sampling heuristic and is considered a Bayesian method.

Action-value methods

The simplest way for an agent to estimate action-values is to average the rewards it has received for each variant so far. Here, Q denotes the agent’s action-value estimate of variant a at timestep t.
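Written out under the usual sample-average definition, the estimate is the mean of the rewards observed for variant a before timestep t. The symbols A_i (the variant shown at timestep i) and R_i (the reward received at timestep i) are assumed here for illustration:

```latex
Q_t(a) = \frac{\sum_{i=1}^{t-1} R_i \cdot \mathbb{1}[A_i = a]}{\sum_{i=1}^{t-1} \mathbb{1}[A_i = a]}
```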

Rather than re-calculating the average every round, we can implement action-value estimation in an incremental fashion, where Q stores the current action-value estimate of each variant a, N counts how often a was shown, and R denotes the reward received at timestep t.
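A common way to write such an incremental update is Q ← Q + (R − Q) / N. The plain-Python sketch below (no Jax, hypothetical function name) checks for a single variant that applying this rule reward by reward gives the same result as averaging all rewards at once.

```python
def incremental_mean(rewards):
    """Applies Q <- Q + (R - Q) / N once per reward and returns the final Q."""
    q, n = 0.0, 0
    for r in rewards:
        n += 1                # one more observation for this variant
        q += (r - q) / n      # move the estimate towards the new reward
    return q

rewards = [1.0, 0.0, 0.0, 1.0, 0.0]
estimate = incremental_mean(rewards)
batch_average = sum(rewards) / len(rewards)

print(estimate, batch_average)
```

The incremental form only needs to remember Q and N per variant, so the memory and work per visit stay constant no matter how long the experiment runs.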

Let us implement two functions to work with action-value estimates. The first function initialises the lookup tables Q and N for each variant a, setting the estimates of all variants to an optimistic initial value of 1 (or a click rate of 100%). The second function updates Q and N according to the incremental update definition described above.

def action_value_init(num_variants):
    """Returns the initial action values"""

    return {
        'n': jnp.ones(num_variants, dtype=jnp.int32),
        'q': jnp.ones(num_variants, dtype=jnp.float32)
    }

def action_value_update(params, variant, clicked):
    """Calculates the updated action values"""

    # Reading n and q parameters of the selected variant
    n, q = params['n'][variant], params['q'][variant]

    # Converting the boolean clicked variable to a float value
    r = clicked.astype(jnp.float32)

    return {
        # Incrementing the counter of the taken action by one
        'n': params['n'].at[variant].add(1),

        # Incrementally updating the action-value estimate
        'q': params['q'].at[variant].add((r - q) / n)
    }

We chose to implement the initialisation and update of the action-value estimates using a functional approach, similar to the environment function. Jax arrays, unlike numpy arrays, are immutable and therefore cannot be updated in place; instead, each update returns a new copy with the applied changes, while the original array remains unchanged.
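The following minimal sketch illustrates this behaviour with the .at[...].add(...) syntax used in the update function. The array name counts is made up for the example.

```python
import jax.numpy as jnp

counts = jnp.zeros(3, dtype=jnp.int32)

# .at[...].add(...) returns a new array; counts itself is left untouched
updated = counts.at[1].add(1)

print(counts)
print(updated)
```

Only updated contains the increment, so a function that receives counts can never change it behind the caller's back.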

Epsilon-greedy policy

The epsilon-greedy policy defines a stochastic approach to balancing the exploration and exploitation trade-off. Guided by the hyperparameter ε, it randomly decides between selecting a uniformly random variant, with probability ε, and selecting the variant a with the highest action-value estimate Q.

In Jax, we can define conditional policies using the cond function. It takes a predicate, two functions, and a variable number of arguments. Depending on the result of the predicate, cond calls one of the two functions, passing it the given arguments.
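Before using cond inside a policy, here is a minimal standalone sketch of how it can be called. The two branch functions, double and negate, are hypothetical and only serve the illustration; both must return values of the same shape and type.

```python
from jax import lax
import jax.numpy as jnp

def double(x):
    return x * 2.0

def negate(x):
    return -x

x = jnp.float32(3.0)

# The predicate is true, so double is called with x
result = lax.cond(x > 0, double, negate, x)

# With a negative input the predicate is false, so negate is called
y = jnp.float32(-3.0)
other = lax.cond(y > 0, double, negate, y)

print(result, other)
```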

def epsilon_greedy_policy(params, timestep, rng, epsilon):
    """
    Randomly selects either the variant with highest action-value,
    or an arbitrary variant.
    """

    # Selecting a random variant
    def explore(q, rng):
        return random.choice(rng, jnp.arange(len(q)))

    # Selecting the variant with the highest action-value estimate
    def exploit(q, rng):
        return jnp.argmax(q)

    # Splitting the random number generator
    uniform_rng, choice_rng = random.split(rng)

    # Deciding randomly whether to explore or to exploit
    return lax.cond(
        random.uniform(uniform_rng) < epsilon,
        explore,
        exploit,
        params['q'],
        choice_rng
    )

Boltzmann policy

The Boltzmann or softmax policy is similar to the epsilon-greedy policy in that it is a stochastic policy based on action-value estimates. This approach randomly samples a variant a from the probability distribution that results from applying the softmax function to the action-value estimates Q. The exploration-exploitation trade-off can be controlled through the temperature hyperparameter τ, where lower temperatures favour exploitation and higher temperatures promote exploration. The probability P of each variant to be selected is defined by:
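Assuming the standard softmax with temperature, which matches the description above, this probability can be written as follows, with the sum running over all variants b:

```latex
P(a) = \frac{\exp\left(Q(a) / \tau\right)}{\sum_{b} \exp\left(Q(b) / \tau\right)}
```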

In Jax, this can be easily implemented by utilizing the choice function from the random module, parameterised by the softmax function applied to the action-value estimates.

def boltzmann_policy(params, timestep, rng, tau):
    """Randomly selects a variant proportional to the current action-values"""

    return random.choice(
        rng,
        jnp.arange(len(params['q'])),
        # Turning the action-value estimates into a probability distribution
        # by applying the softmax function controlled by tau
        p=jax.nn.softmax(params['q'] / tau)
    )
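To get a feeling for the temperature, the plain-Python sketch below computes the softmax probabilities for a few values of tau. It uses the simulated click rates as stand-in action-value estimates, which is an assumption made purely for illustration, and the softmax helper is written by hand rather than taken from Jax.

```python
import math

def softmax(values, tau):
    """Softmax of values / tau, computed in a numerically stable way."""
    scaled = [v / tau for v in values]
    largest = max(scaled)
    exps = [math.exp(s - largest) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Stand-in estimates: the click rates used in the simulation
q = [0.042, 0.03, 0.035, 0.038, 0.045]

probabilities = {tau: softmax(q, tau) for tau in (1.0, 0.01, 0.001)}

for tau, probs in probabilities.items():
    print(tau, [round(p, 3) for p in probs])
```

With estimates on the scale of click rates, a temperature of 1.0 gives an almost uniform distribution, while much smaller temperatures concentrate the probability on the best variant. The temperature therefore has to be judged relative to the scale of the rewards.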

Upper-Confidence-Bound policy

We will now discuss a policy with a deterministic approach to balancing exploration and exploitation. Like the formerly discussed policies, it encourages exploitation by prioritising variants with high action-value estimates. However, instead of relying on stochasticity for exploration, it leverages a heuristic that encourages the selection of variants with low selection counts.

The heuristic accomplishes this by remaining optimistic in the face of uncertainty, meaning that every variant is given the benefit of the doubt of being better than our current action-value estimate. During the experiment, each time a variant is selected and a real reward is observed, we become more confident in our action-value estimate and decrease the benefit of the doubt for that variant.

Formally, we define this optimistic guess as the variant’s upper-confidence-bound (UCB), which is scaled by the confidence hyperparameter c and added to the current action-value estimate. Finally, we select the variant with the highest sum.
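A reconstruction of this selection rule, written to agree with the implementation further down, where t is the current timestep and N_t(a) the selection count of variant a:

```latex
A_t = \arg\max_{a} \left[ Q_t(a) + c \sqrt{\frac{\ln t}{N_t(a)}} \right]
```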

The UCB policy is the first policy we have discussed that rewards exploration as well as exploitation:

  • Given two variants with the same action-value estimate Q, we will select the variant with a lower selection count N.
  • Given two variants with the same selection count N, we will select the variant with a higher action-value estimate Q.
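Both properties can be checked numerically. The sketch below is plain Python with a hypothetical ucb_score helper and made-up estimates and counts; it computes the score that the policy maximises for two pairs of variants.

```python
import math

def ucb_score(q, n, t, c):
    """Action-value estimate plus the scaled upper confidence bound."""
    return q + c * math.sqrt(math.log(t) / n)

t, c = 100, 2.0

# Same estimate, different selection counts
rarely_shown = ucb_score(q=0.04, n=5, t=t, c=c)
often_shown = ucb_score(q=0.04, n=50, t=t, c=c)

# Same selection count, different estimates
low_estimate = ucb_score(q=0.03, n=20, t=t, c=c)
high_estimate = ucb_score(q=0.04, n=20, t=t, c=c)

print(rarely_shown > often_shown, high_estimate > low_estimate)
```

In both comparisons the variant named first in the corresponding bullet point receives the higher score and would therefore be selected.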

To ensure a consistent function definition for all policies, the UCB policy takes a random number generator parameter, even though it is a deterministic algorithm.

def upper_confidence_bound_policy(params, timestep, rng, confidence):
    """Selects the variant with highest action-value plus upper confidence bound"""

    # Read n and q parameters
    n, q = params['n'], params['q']

    # Calculating each variant's upper confidence bound
    # and selecting the variant with the highest value
    return jnp.argmax(q + confidence * jnp.sqrt(jnp.log(timestep) / n))

Bayesian methods

The action-value methods discussed so far make point estimates of the unknown click rates of our five variants. We now adopt a more Bayesian approach and treat the variants’ click rates as a set of independent random variables. Specifically, we define our current belief about a variant’s click rate C by modelling it as a random variable following a beta distribution².

The beta distribution is characterized by two parameters, a and b, which can be interpreted as the number of times variant i was clicked versus the number of times it was not clicked when shown. When comparing Bayesian methods with action-value methods, we can use the expected value E of the random variable C to define our best guess, which can be determined by dividing the number of times a variant was clicked by the number of times it was shown:
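Using the standard mean of a beta distribution, and writing a_i and b_i for the parameters of variant i (the subscripts are added here for illustration), this best guess is:

```latex
\mathbb{E}[C_i] = \frac{a_i}{a_i + b_i}
```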

We define two functions to work with beta distributions, analogous to the action-value methods. The first initialises a uniform beta prior for each variant, while the second function calculates the posterior beta distribution by incrementing either the a or the b parameter by one.

def beta_init(num_variants):
    """Returns the initial hyperparameters of the beta distribution"""

    return {
        'a': jnp.ones(num_variants, dtype=jnp.int32),
        'b': jnp.ones(num_variants, dtype=jnp.int32)
    }

def beta_update(params, variant, clicked):
    """Calculates the updated hyperparameters of the beta distribution"""

    # Incrementing alpha by one
    def increment_alpha(a, b):
        return {'a': a.at[variant].add(1), 'b': b}

    # Incrementing beta by one
    def increment_beta(a, b):
        return {'b': b.at[variant].add(1), 'a': a}

    # Incrementing either alpha or beta
    # depending on whether or not the user clicked
    return lax.cond(
        clicked,
        increment_alpha,
        increment_beta,
        params['a'],
        params['b']
    )

Thompson sampling policy

The Thompson sampling (TS) policy is based on a two-step heuristic, which works by drawing a random click rate sample from each variant’s beta distribution, and then selecting the variant with the highest sampled click rate. The feedback we receive is then immediately incorporated into the beta distribution of that variant, narrowing the distribution around the actual click rate.

Like the UCB policy, this approach rewards both exploration and exploitation:

  • Given two variants with the same mean, the variant with a higher variance has a higher chance of being selected, since it has a broader distribution and will more often produce high click rates when sampled.
  • Given two variants with the same variance, the variant with the higher mean is selected more often, as it is more likely to produce the greater sampled click rate.
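Both effects can be observed with a small simulation. The sketch below uses Python's standard random module instead of Jax, and the beta parameters and the threshold are made up for illustration. The threshold stands in for the sample drawn by a stronger competing variant that has to be beaten.

```python
import random

rng = random.Random(0)
num_samples = 10_000
threshold = 0.7

# Same mean (0.5), but Beta(2, 2) is much broader than Beta(20, 20)
wide = [rng.betavariate(2, 2) for _ in range(num_samples)]
narrow = [rng.betavariate(20, 20) for _ in range(num_samples)]

# How often does each belief produce a sample above the threshold?
wide_frac = sum(s > threshold for s in wide) / num_samples
narrow_frac = sum(s > threshold for s in narrow) / num_samples

# Same variance, but Beta(30, 20) has a higher mean than Beta(20, 30)
higher_mean_wins = sum(
    rng.betavariate(30, 20) > rng.betavariate(20, 30)
    for _ in range(num_samples)
) / num_samples

print(wide_frac, narrow_frac, higher_mean_wins)
```

The broad belief clears the threshold far more often than the narrow one, and the belief with the higher mean wins the large majority of direct comparisons.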

For the implementation of the TS policy we use Jax’s random module to sample random click rates based on the variants’ beta distribution and then select the variant with the highest sample.

def thompson_policy(params, timestep, rng):
    """
    Randomly sampling click rates for all variants
    and selecting the variant with the highest sample
    """

    return jnp.argmax(random.beta(rng, params['a'], params['b']))

Implementing the evaluation

With the environment and policies in place, we can finally conduct our experiment and compare the results. Before we continue, I want to highlight that this experiment is intended to demonstrate how the algorithms work, not to empirically evaluate their performance. In order to execute the following implementations, we need to import the partial function from Python’s functools library and pyplot from matplotlib:

from functools import partial
from matplotlib import pyplot as plt

The evaluate function is responsible for executing the visit function, as guided by the set of constants, the parameter initialisation and update functions, and the policy. The output of the evaluation is the final environment state, which includes the policy’s final parameters and final random number generator state, as well as the click history. We leverage Jax’s scan function to ensure that the experiment state is carried over and user clicks are accumulated. Moreover, just-in-time (JIT) compilation is employed to optimise performance, while partial is used to fix the click_rates, policy_fn, and update_fn parameters of the visit function so that it matches the expected signature.

def evaluate(policy_fn, init_fn, update_fn):
    """
    Simulating the environment for NUM_VISITS users
    while accumulating the click history
    """

    return lax.scan(
        # Compiling the visit function using just-in-time (JIT) compilation
        # for better performance
        jax.jit(
            # Partially applying the visit function by fixing
            # the click_rates, policy_fn, and update_fn parameters
            # (click_rates is converted to a Jax array so that it
            # can be indexed by the traced variant)
            partial(
                visit,
                click_rates=jnp.array(CLICK_RATES),
                policy_fn=policy_fn,
                update_fn=update_fn
            )
        ),

        # Initialising the experiment state using
        # init_fn and a new PRNG key
        (init_fn(len(CLICK_RATES)), random.PRNGKey(SEED)),

        # Setting the number of steps in our environment
        jnp.arange(1, NUM_VISITS + 1)
    )
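For readers who have not used scan before, here is a minimal standalone sketch. The step function is a hypothetical running sum: like visit with its extra parameters fixed, it takes the carried state and the current input and returns the next state together with a per-step output.

```python
from jax import lax
import jax.numpy as jnp

def step(carry, x):
    """Adds the current input to the carried sum and records the new sum."""
    new_carry = carry + x
    return new_carry, new_carry

inputs = jnp.arange(1, 5, dtype=jnp.int32)

# scan threads the carry through all steps and stacks the per-step outputs
final, history = lax.scan(step, jnp.int32(0), inputs)

print(final, history)
```

scan returns the final carry and the stacked outputs, which in the evaluate function correspond to the final experiment state and the click history.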

The regret function is the last component of our evaluation. In RL lingo, regret is defined as the amount of reward we miss out on by taking a suboptimal action, and can only be calculated when the optimal action is known. Given the click history, our regret function calculates the regret of the action taken for every step of the environment.

def regret(history):
    """Calculates the regret for every action in the environment history"""

    # Calculating regret with regard to picking the optimal (0.045) variant
    def fn(acc, reward):
        n, v = acc[0] + 1, acc[1] + reward
        return (n, v), 0.045 - (v / n)

    # Calculating regret values over entire history
    _, result = lax.scan(
        fn,
        (jnp.array(0), jnp.array(0)),
        history
    )

    return result
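To see what these values look like, the plain-Python sketch below mirrors the same calculation for a very short, made-up click history. The function name running_regret and its optimal_rate parameter are hypothetical.

```python
def running_regret(history, optimal_rate):
    """Optimal click rate minus the average click rate observed so far."""
    clicks, regrets = 0, []
    for n, clicked in enumerate(history, start=1):
        clicks += int(clicked)
        regrets.append(optimal_rate - clicks / n)
    return regrets

values = running_regret([False, True, False, False], optimal_rate=0.045)

print(values)
```

Since the observed average is a noisy estimate, the values can swing widely and even turn negative early on, when a lucky click pushes the observed rate above the optimal rate. They only become informative after many visits.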

Next, we run the evaluation for all four policies and visualise the regret. Note that the hyperparameters for the policies have not been fine-tuned; instead, they are set to generic default values that are suitable for a wide variety of MAB problems.

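# Epsilon greedy policy
(epsilon_greedy_params, _), epsilon_greedy_history = evaluate(
    policy_fn=partial(epsilon_greedy_policy, epsilon=0.1),
    init_fn=action_value_init,
    update_fn=action_value_update
)

# Boltzmann policy
(boltzmann_params, _), boltzmann_history = evaluate(
    policy_fn=partial(boltzmann_policy, tau=1.0),
    init_fn=action_value_init,
    update_fn=action_value_update
)

# Upper confidence bound policy
(ucb_params, _), ucb_history = evaluate(
    policy_fn=partial(upper_confidence_bound_policy, confidence=2),
    init_fn=action_value_init,
    update_fn=action_value_update
)

# Thompson sampling policy
(ts_params, _), ts_history = evaluate(
    policy_fn=thompson_policy,
    init_fn=beta_init,
    update_fn=beta_update
)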

# Visualisation
fig, ax = plt.subplots(figsize=(16, 8))

x = jnp.arange(1, NUM_VISITS + 1)

ax.set_xlabel('Number of visits')

ax.plot(x, jnp.repeat(jnp.mean(jnp.array(CLICK_RATES)), NUM_VISITS), label='A/B Testing')
ax.plot(x, regret(epsilon_greedy_history), label='Espilon Greedy Policy')
ax.plot(x, regret(boltzmann_history), label='Boltzmann Policy')
ax.plot(x, regret(ucb_history), label='UCB Policy')
ax.plot(x, regret(ts_history), label='TS Policy')


The resulting graph plots the regret of our policies over the number of visits. We can clearly see that all policies outperform a hypothetical A/B testing scenario in terms of regret. The epsilon-greedy and TS policies appear to perform slightly better than the Boltzmann and UCB policies in this particular scenario.
