# Cmput 455 sample code
# How often do bandits based on Bernoulli experiments make the wrong choice?
# Written by Martin Mueller

from scipy.stats import binom


def binomial(k, n, p):
    """Return the probability of exactly k wins in n Bernoulli(p) experiments.

    Thin wrapper around scipy's binomial probability mass function. Also see:
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.binom.html
    """
    return binom.pmf(k, n, p)


def worse(p1, n1, k2, n2):
    """Return the probability that arm 1's observed winrate loses to k2/n2.

    Arm 1 wins each of its n1 tries with probability p1. The result sums
    the probability of every outcome k1 with k1/n1 < k2/n2; an outcome
    with exactly equal winrates is counted at half weight, i.e. the
    decision between tied arms is treated as 50-50.

    Winrates are compared as k1*n2 versus k2*n1 (both sides multiplied by
    n1*n2) so the comparison stays in exact integer arithmetic.
    """
    threshold = k2 * n1
    total = 0
    # k1 can never exceed n1. Bounding the loop explicitly (instead of
    # looping only on the winrate comparison) guarantees termination when
    # n2 == 0, where k1*n2 stays 0 and the comparison alone never fails.
    for k1 in range(n1 + 1):
        scaled = k1 * n2
        if scaled > threshold:
            # Arm 1's winrate is already strictly better; larger k1 only
            # makes it better still.
            break
        prob = binomial(k1, n1, p1)
        # Strictly worse winrate counts fully, an exact tie counts as half.
        total += prob if scaled < threshold else prob / 2
    return total


def wrongChoice(n1, p1, n2, p2):
    """Return the probability that the worse arm gets the better evaluation.

    arm1 has win probability p1 and was pulled n1 times.
    arm2 has win probability p2 and was pulled n2 times.
    p1 must be better than p2; this is checked with an assert, so an
    AssertionError is raised otherwise (unless assertions are disabled).

    The result sums, over every possible win count k of arm 2, the
    probability that arm 2 has k wins AND arm 1 has a worse winrate than
    k/n2 (ties counted as half, see worse()).
    """
    assert p1 > p2
    return sum(
        binomial(k, n2, p2) * worse(p1, n1, k, n2)  # k wins out of n2 for arm 2
        for k in range(n2 + 1)
    )


def test(p1, p2, maxN, printall=True):
    """Print the wrong-choice probability when both arms get n pulls each.

    n runs from 1 to maxN. With printall=True every n is printed;
    otherwise only n = 1, 2, 4, 8, ... are printed.
    """
    print("p1 = {}, p2 = {}".format(p1, p2))
    next_print = 1
    for n in range(1, maxN + 1):
        if n != next_print:
            continue
        # Only evaluate the rows that are actually printed; n == 1 is
        # always printed, so the p1 > p2 check still runs whenever the
        # loop body runs at all.
        result = wrongChoice(n, p1, n, p2)
        print("Both have {} simulations. Prob. of wrong arm choice {}"
              .format(n, result))
        if printall:
            next_print += 1
        else:
            next_print *= 2


# `test` is a demo driver, not a unit test: keep pytest from collecting it
# (its name matches pytest's default function prefix).
test.__test__ = False


def main():
    """Run the demo: all n up to 10, then powers of two up to 32."""
    arm_pairs = ((0.9, 0.2), (0.5, 0.4), (0.5, 0.49), (0.5, 0.49999))
    for p1, p2 in arm_pairs:
        test(p1=p1, p2=p2, maxN=10)
    for p1, p2 in arm_pairs:
        test(p1=p1, p2=p2, maxN=32, printall=False)


if __name__ == "__main__":
    main()