A NEET Who Is Practically a Grad Student @ Estonia

I'm lazing around by myself in a country called Estonia.

A/B Testing with Python (Bandit)

In this article, I am going to implement a Bayesian bandit algorithm (Thompson sampling).

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import beta

NUM_TRIALS = 2000
BANDIT_PROBABILITIES = [0.2, 0.5, 0.75]

There are three bandits, with true win probabilities of 0.2, 0.5, and 0.75. These values are arbitrary, so feel free to choose your own.
The number of trials is 2000.

Next, I define a class called Bandit.

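class Bandit:
    def __init__(self, p):
        self.p = p  # true probability of winning
        self.a = 1  # Beta parameters: Beta(1, 1) is the uniform prior
        self.b = 1

    def pull(self):  # pull the slot machine's arm: True (win) with probability p
        return np.random.random() < self.p

    def sample(self):  # draw one value from the current Beta(a, b) distribution
        return np.random.beta(self.a, self.b)

    def update(self, x):  # x is 1 (win) or 0 (loss)
        self.a += x
        self.b += 1 - x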

Each Bandit object works like a single slot machine.
a and b are the parameters of its Beta distribution. Both start at 1, which gives Beta(1, 1), the uniform distribution on [0, 1]. p is the true probability of winning.
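
For reference, here is the standard Beta density, where B(a, b) denotes the Beta function. Setting a = b = 1 makes both exponents zero, and B(1, 1) = 1, so the starting distribution is flat on [0, 1]:

```latex
\mathrm{Beta}(p \mid a, b) = \frac{p^{a-1}(1-p)^{b-1}}{B(a, b)},
\qquad
\mathrm{Beta}(p \mid 1, 1) = \frac{p^{0}(1-p)^{0}}{B(1, 1)} = 1 \quad \text{for } 0 \le p \le 1
```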

pull() simulates pulling the arm. It returns True (a win) with probability p and False otherwise. sample() draws one random value from the bandit's current Beta distribution.
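
Here is a quick sanity check of pull() and sample(). It assumes a trimmed copy of the class (only the parts needed here) and a fixed seed chosen just for reproducibility. Over many pulls, the fraction of wins should land close to p. Every sample should lie in [0, 1].

```python
import numpy as np

# Trimmed copy of the article's Bandit class so this snippet runs on its own.
class Bandit:
    def __init__(self, p):
        self.p = p  # true win probability
        self.a = 1  # Beta parameter a (prior Beta(1, 1))
        self.b = 1  # Beta parameter b

    def pull(self):
        # True (win) with probability p, False otherwise
        return np.random.random() < self.p

    def sample(self):
        # one random draw from the current Beta(a, b) distribution
        return np.random.beta(self.a, self.b)

np.random.seed(0)  # fixed seed so the run is reproducible
bandit = Bandit(0.75)

pulls = [bandit.pull() for _ in range(10_000)]
win_rate = sum(pulls) / len(pulls)           # should be close to p = 0.75
draws = [bandit.sample() for _ in range(5)]  # draws from Beta(1, 1), all in [0, 1]

print("empirical win rate: %.3f" % win_rate)
print("samples from Beta(1, 1):", ["%.4f" % d for d in draws])
```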

Lastly, update() incorporates the result of a pull. x is either 0 (loss) or 1 (win), so a win adds 1 to a and a loss adds 1 to b.
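
The sketch below shows the arithmetic behind update(), written as a plain function. The name update_params is hypothetical and does not appear in the article. The Beta distribution is conjugate to the Bernoulli likelihood, so one outcome x turns Beta(a, b) into Beta(a + x, b + 1 - x). This means a - 1 counts the wins and b - 1 counts the losses. The outcome sequence is made up for illustration.

```python
def update_params(a, b, x):
    """Return the Beta parameters after observing one outcome x (1 = win, 0 = loss)."""
    return a + x, b + 1 - x

a, b = 1, 1              # uniform prior Beta(1, 1)
outcomes = [1, 0, 1, 1]  # hypothetical results: 3 wins, 1 loss
for x in outcomes:
    a, b = update_params(a, b, x)

posterior_mean = a / (a + b)            # mean of Beta(a, b)
print(a, b, round(posterior_mean, 4))   # 4 2 0.6667
```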

def plot(bandits, trial):
    x = np.linspace(0, 1, 200)
    for b in bandits:
        y = beta.pdf(x, b.a, b.b) 
        plt.plot(x, y, label="real p: %.4f" % b.p)
    plt.title("Bandit distributions after %s trials" % trial)
    plt.legend()
    plt.show()

This plots the PDF of each bandit's current Beta distribution, so we can compare them on the same chart.
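
If you would rather read numbers than look at plots, you can summarize each distribution by its posterior mean and a 95% equal-tailed credible interval. The sketch below does this with scipy.stats.beta. The helper summarize and the (true_p, a, b) tuples are hypothetical and do not come from the article.

```python
from scipy.stats import beta

def summarize(params):
    """Print mean and 95% interval for each (true_p, a, b) tuple; also return them."""
    rows = []
    for p, a, b in params:
        mean = beta.mean(a, b)                # posterior mean a / (a + b)
        low, high = beta.interval(0.95, a, b)  # central 95% interval
        rows.append((p, mean, low, high))
        print("real p: %.4f  mean: %.4f  95%% interval: [%.4f, %.4f]"
              % (p, mean, low, high))
    return rows

# made-up parameters: an untouched bandit and one with 30 wins, 10 losses
rows = summarize([(0.2, 1, 1), (0.75, 31, 11)])
```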

OK. Now we need to run the actual experiment.

def experiment():
    bandits = [Bandit(p) for p in BANDIT_PROBABILITIES]
    
    sample_points = [5, 10, 20, 50, 100, 200, 500, 1000, 1500, 1999]
    
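    for i in range(NUM_TRIALS):
        # Thompson sampling: draw one sample from each bandit's current
        # Beta distribution and remember the bandit with the largest sample
        bestb = None
        maxsample = -1
        allsamples = []  # kept only for printing
        for b in bandits:
            sample = b.sample()
            allsamples.append("%.4f" % sample)
            if sample > maxsample:
                maxsample = sample
                bestb = b
        if i in sample_points:
            print("Current samples: %s" % allsamples)
            plot(bandits, i)

        # pull only the chosen bandit's arm and update its distribution
        x = bestb.pull()
        bestb.update(x)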

if __name__ == '__main__':
    experiment()

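First, the bandits are initialized, one for each entry in BANDIT_PROBABILITIES.
sample_points are the trial numbers at which we print the current samples and show a plot. Any values are fine.
In every trial, we draw a sample from each bandit and keep track of the maximum sample and the bandit that produced it (bestb). Only that bandit's arm is pulled, and its distribution is updated with the result.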
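
To check how the pulls end up distributed without clicking through plots, here is a plot-free sketch of the same Thompson sampling loop. It assumes the article's settings: 2000 trials and true win probabilities 0.2, 0.5, and 0.75. It keeps the Beta parameters in NumPy arrays instead of Bandit objects. It also uses NumPy's Generator API (np.random.default_rng) in place of the global np.random functions. The function name run_experiment is hypothetical.

```python
import numpy as np

def run_experiment(probs, num_trials, seed=0):
    """Thompson sampling on Bernoulli bandits; returns pulls per bandit and Beta params."""
    rng = np.random.default_rng(seed)
    probs = np.asarray(probs)
    a = np.ones(len(probs))  # Beta parameter a per bandit (uniform prior)
    b = np.ones(len(probs))  # Beta parameter b per bandit
    pulls = np.zeros(len(probs), dtype=int)
    for _ in range(num_trials):
        samples = rng.beta(a, b)          # one draw per bandit
        j = int(np.argmax(samples))       # bandit with the largest sample
        x = int(rng.random() < probs[j])  # pull its arm: 1 = win, 0 = loss
        a[j] += x
        b[j] += 1 - x
        pulls[j] += 1
    return pulls, a, b

pulls, a, b = run_experiment([0.2, 0.5, 0.75], 2000)
print("pulls per bandit:", pulls)
print("posterior means:", np.round(a / (a + b), 4))
```

With these settings, most pulls should go to the 0.75 bandit. Unlike a fixed-split A/B test, the bandit shifts traffic toward the better option as evidence accumulates.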