This animation illustrates convexity in the mean parameters of a conjugate likelihood–prior pair.
Note that the posterior mean is a convex combination of the MLE and the prior mean, so it always lies between them.

Author: [Evgenii Egorov](mailto:egorov.evgenyy@ya.ru)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import numpy as np
import scipy.stats as stats

class BetaPosterior:
    """Beta posterior over a Bernoulli success probability.

    The two Beta shape parameters are kept in ``self.tau`` as
    ``[alpha, beta]``; observing data adds the number of ones to ``alpha``
    and the number of zeros to ``beta``.
    """

    def __init__(self, prior):
        """Initialise the posterior with the shape parameters in *prior*.

        Args:
            prior: Array-like ``[alpha, beta]``. The values are copied into
                a new float array, so later updates never modify the
                object passed in by the caller.
        """
        # np.array copies by default; storing the argument itself would let
        # the in-place ``+=`` in update() leak into the caller's prior.
        self.tau = np.array(prior, dtype=float)

    def update(self, X):
        """Fold a batch of 0/1 observations into the shape parameters.

        Args:
            X: Array-like of Bernoulli outcomes (0 or 1). Every element is
                treated as one trial.
        """
        X = np.asarray(X)
        successes = np.sum(X)
        self.tau[0] += successes
        # Failures are all trials that were not successes.
        self.tau[1] += X.size - successes

    def sample(self, size):
        """Draw *size* samples from the current Beta distribution."""
        return stats.beta.rvs(*self.tau, size=size)

    def mean(self):
        """Return the mean ``alpha / (alpha + beta)`` of the current Beta."""
        alpha, beta = self.tau
        return alpha / (alpha + beta)
    
class BetaMLE:
    """Running maximum-likelihood estimate of a Bernoulli success probability.

    Keeps the number of trials seen so far (``N``) and the number of ones
    among them (``sum_stat``); the estimate is their ratio.
    """

    def __init__(self):
        """Start with no observations recorded."""
        self.N = 0.
        self.sum_stat = 0.

    def update(self, X):
        """Add a batch of 0/1 observations to the running statistics.

        Args:
            X: Array-like of Bernoulli outcomes (0 or 1). Every element is
                counted as one trial.
        """
        X = np.asarray(X)
        # Count trials and successes over the same set of elements.
        self.N += X.size
        self.sum_stat += np.sum(X)

    def mle(self):
        """Return the fraction of ones among all observations seen so far.

        Calling this before any observation has been recorded divides by
        zero.
        """
        return self.sum_stat / self.N

In [None]:
# Success probability of the Bernoulli distribution that generates the data.
p_true = 0.1
# Observations drawn per update step, and number of update steps.
batch_size, epoch_num = 1, 350

# Shape parameters [alpha, beta] of the Beta prior.
prior_params = np.ones(2)

# Estimators to be fed with data; the posterior receives its own copy of
# the prior parameters.
p_mle = BetaMLE()
p_posterior = BetaPosterior(prior_params.copy())

In [None]:
from IPython.display import clear_output

# Start from a fresh prior and an empty estimator so that re-running this
# cell does not continue from the state left by a previous run.
p_posterior = BetaPosterior(prior_params.copy())
p_mle = BetaMLE()

# History of both point estimates, one entry per epoch.
mle_ls = []
posterior_ls = []

# The prior parameters do not change inside the loop, so the prior mean is
# computed once instead of once per frame.
prior_mean = prior_params[0] / (prior_params[0] + prior_params[1])

for epoch in range(epoch_num):
    # Draw a new batch from the true Bernoulli and feed it to both estimators.
    X_sample = stats.bernoulli.rvs(p_true, size=batch_size)

    p_mle.update(X_sample)
    p_posterior.update(X_sample)

    mle_ls.append(p_mle.mle())
    posterior_ls.append(p_posterior.mean())

    # Visualise the current posterior through a histogram of its samples.
    p_sample = p_posterior.sample(500)

    fig = plt.figure(figsize=(10, 10))
    # plt.hist returns the bin counts it draws; reuse them for the marker
    # height instead of computing the same histogram a second time.
    counts = plt.hist(p_sample, alpha=0.3)[0]
    y_max = np.max(counts)

    plt.vlines(mle_ls[-1], 0, y_max, color='green', label='mle')
    plt.vlines(posterior_ls[-1], 0, y_max, color='orange', label='posterior mean')
    plt.vlines(p_true, 0, y_max, color='red', label='true')
    plt.vlines(prior_mean, 0, y_max, color='black', label='prior mean')
    plt.legend(loc=2)
    # Fixed axes keep consecutive frames visually comparable.
    plt.xlim(0, 0.7)
    plt.yticks(ticks=[i * 20 for i in range(5)])
    plt.grid()

    plt.show()
    plt.pause(0.3)
    # Release the figure explicitly so backends that keep figures open after
    # show() do not accumulate one figure per epoch.
    plt.close(fig)
    # wait=True defers clearing until the next frame is ready to be shown.
    clear_output(wait=True)