Lab 10 - Sampling Controls
Goal: apply top-k and top-p filtering, then renormalize probabilities.
Info: Top-k and top-p filtering
Reduce the candidate set before sampling, then renormalize so the kept probabilities sum to 1. Lower k or p values focus the output on the most likely tokens; higher values add diversity. Configure them per use case.
Info: Profile versioning
Log all temperature and top-p changes. Changes to sampling controls deserve the same change control as any other configuration update.
In [ ]:
Copied!
# Toy next-token distribution used throughout this lab. The six
# probabilities are paired positionally with the six labels below.
base_probs = [
    0.40, 0.22, 0.14,
    0.10, 0.08, 0.06,
]
labels = ['token_' + suffix for suffix in 'ABCDEF']
def renormalize(items):
    """Rescale (label, probability) pairs so the probabilities sum to 1.

    Args:
        items: Iterable of ``(label, probability)`` pairs. One-shot
            iterables such as ``zip(...)`` are accepted.

    Returns:
        A list of ``(label, probability / total)`` pairs in input order,
        or an empty list when ``items`` is empty.

    Raises:
        ZeroDivisionError: If ``items`` is non-empty and its probabilities
            sum to zero.
    """
    # Materialize first: the input is traversed twice (sum, then rescale),
    # and a one-shot iterator would be exhausted by the first pass, which
    # would silently produce an empty result.
    pairs = list(items)
    total = sum(prob for _, prob in pairs)
    return [(label, prob / total) for label, prob in pairs]
# Show the unfiltered distribution as (label, probability) pairs.
print('Original:', [*zip(labels, base_probs)])
# Toy next-token distribution used throughout this lab. The six
# probabilities are paired positionally with the six labels below.
base_probs = [
    0.40, 0.22, 0.14,
    0.10, 0.08, 0.06,
]
labels = ['token_' + suffix for suffix in 'ABCDEF']
def renormalize(items):
    """Rescale (label, probability) pairs so the probabilities sum to 1.

    Args:
        items: Iterable of ``(label, probability)`` pairs. One-shot
            iterables such as ``zip(...)`` are accepted.

    Returns:
        A list of ``(label, probability / total)`` pairs in input order,
        or an empty list when ``items`` is empty.

    Raises:
        ZeroDivisionError: If ``items`` is non-empty and its probabilities
            sum to zero.
    """
    # Materialize first: the input is traversed twice (sum, then rescale),
    # and a one-shot iterator would be exhausted by the first pass, which
    # would silently produce an empty result.
    pairs = list(items)
    total = sum(prob for _, prob in pairs)
    return [(label, prob / total) for label, prob in pairs]
# Show the unfiltered distribution as (label, probability) pairs.
print('Original:', [*zip(labels, base_probs)])
In [ ]:
Copied!
def top_k(labels, probs, k):
    """Keep the ``k`` most probable tokens and renormalize them.

    Args:
        labels: Token labels, paired positionally with ``probs``.
        probs: Probability of each label.
        k: Number of highest-probability tokens to keep; must be >= 0.

    Returns:
        The result of ``renormalize`` applied to the kept
        ``(label, probability)`` pairs, ordered from most to least probable.

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k < 0:
        # A negative slice bound would silently drop the *last* |k| tokens
        # rather than keep the top k, so reject it explicitly.
        raise ValueError('k must be non-negative, got {!r}'.format(k))
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    return renormalize(ranked[:k])
def top_p(labels, probs, p):
    """Keep the smallest high-probability prefix whose mass reaches ``p``.

    Tokens are ranked from most to least probable and accumulated until the
    running total is at least ``p``; the token that crosses the threshold is
    included. If the threshold is never reached, every token is kept.

    Args:
        labels: Token labels, paired positionally with ``probs``.
        probs: Probability of each label.
        p: Cumulative-probability threshold.

    Returns:
        The result of ``renormalize`` applied to the kept
        ``(label, probability)`` pairs, ordered from most to least probable.
    """
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    cumulative = 0.0
    cutoff = 0  # number of leading tokens to keep; stays 0 for empty input
    for cutoff, (_, prob) in enumerate(ranked, start=1):
        cumulative += prob
        if cumulative >= p:
            break
    return renormalize(ranked[:cutoff])
# Run each filter on the toy distribution and print its result with the
# probabilities rounded to four decimals for readability.
for caption, select, setting in (
    ('top-k (k=3):', top_k, 3),
    ('top-p (p=0.8):', top_p, 0.8),
):
    print(caption, [(token, round(prob, 4)) for token, prob in select(labels, base_probs, setting)])
def top_k(labels, probs, k):
    """Keep the ``k`` most probable tokens and renormalize them.

    Args:
        labels: Token labels, paired positionally with ``probs``.
        probs: Probability of each label.
        k: Number of highest-probability tokens to keep; must be >= 0.

    Returns:
        The result of ``renormalize`` applied to the kept
        ``(label, probability)`` pairs, ordered from most to least probable.

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k < 0:
        # A negative slice bound would silently drop the *last* |k| tokens
        # rather than keep the top k, so reject it explicitly.
        raise ValueError('k must be non-negative, got {!r}'.format(k))
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    return renormalize(ranked[:k])
def top_p(labels, probs, p):
    """Keep the smallest high-probability prefix whose mass reaches ``p``.

    Tokens are ranked from most to least probable and accumulated until the
    running total is at least ``p``; the token that crosses the threshold is
    included. If the threshold is never reached, every token is kept.

    Args:
        labels: Token labels, paired positionally with ``probs``.
        probs: Probability of each label.
        p: Cumulative-probability threshold.

    Returns:
        The result of ``renormalize`` applied to the kept
        ``(label, probability)`` pairs, ordered from most to least probable.
    """
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    cumulative = 0.0
    cutoff = 0  # number of leading tokens to keep; stays 0 for empty input
    for cutoff, (_, prob) in enumerate(ranked, start=1):
        cumulative += prob
        if cumulative >= p:
            break
    return renormalize(ranked[:cutoff])
# Run each filter on the toy distribution and print its result with the
# probabilities rounded to four decimals for readability.
for caption, select, setting in (
    ('top-k (k=3):', top_k, 3),
    ('top-p (p=0.8):', top_p, 0.8),
):
    print(caption, [(token, round(prob, 4)) for token, prob in select(labels, base_probs, setting)])
Visualization: original vs top-k vs top-p
This compares how probability mass shifts after filtering and renormalization.
In [ ]:
Copied!
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
# Short single-letter labels keep the x-axis ticks compact; the
# probabilities are paired positionally with them.
labels = list('ABCDEF')
base_probs = [
    0.40, 0.22, 0.14,
    0.10, 0.08, 0.06,
]
def renormalize(items):
    """Rescale (label, probability) pairs so the probabilities sum to 1.

    Args:
        items: Iterable of ``(label, probability)`` pairs. One-shot
            iterables such as ``zip(...)`` are accepted.

    Returns:
        A list of ``(label, probability / total)`` pairs in input order,
        or an empty list when ``items`` is empty.

    Raises:
        ZeroDivisionError: If ``items`` is non-empty and its probabilities
            sum to zero.
    """
    # Materialize first: the input is traversed twice (sum, then rescale),
    # and a one-shot iterator would be exhausted by the first pass, which
    # would silently produce an empty result.
    pairs = list(items)
    total = sum(prob for _, prob in pairs)
    return [(label, prob / total) for label, prob in pairs]
def top_k(labels, probs, k):
    """Keep the ``k`` most probable tokens and renormalize them.

    Args:
        labels: Token labels, paired positionally with ``probs``.
        probs: Probability of each label.
        k: Number of highest-probability tokens to keep; must be >= 0.

    Returns:
        A dict built from the ``renormalize`` result for the kept
        ``(label, probability)`` pairs, keyed by label.

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k < 0:
        # A negative slice bound would silently drop the *last* |k| tokens
        # rather than keep the top k, so reject it explicitly.
        raise ValueError('k must be non-negative, got {!r}'.format(k))
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    return dict(renormalize(ranked[:k]))
def top_p(labels, probs, p):
    """Keep the smallest high-probability prefix whose mass reaches ``p``.

    Tokens are ranked from most to least probable and accumulated until the
    running total is at least ``p``; the token that crosses the threshold is
    included. If the threshold is never reached, every token is kept.

    Args:
        labels: Token labels, paired positionally with ``probs``.
        probs: Probability of each label.
        p: Cumulative-probability threshold.

    Returns:
        A dict built from the ``renormalize`` result for the kept
        ``(label, probability)`` pairs, keyed by label.
    """
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    cumulative = 0
    cutoff = 0  # number of leading tokens to keep; stays 0 for empty input
    for cutoff, (_, prob) in enumerate(ranked, start=1):
        cumulative += prob
        if cumulative >= p:
            break
    return dict(renormalize(ranked[:cutoff]))
# Filter the toy distribution both ways; the results are looked up by label.
kdist = top_k(labels, base_probs, 3)
pdist = top_p(labels, base_probs, 0.8)

# Align all three series on the full label list: a token removed by a
# filter is missing from its dict and is drawn with height 0.0.
orig = base_probs
kvals = [kdist.get(token, 0.0) for token in labels]
pvals = [pdist.get(token, 0.0) for token in labels]

# Three bars per token: shifted left by one bar width, centered, and
# shifted right by one bar width.
x = range(len(labels))
w = 0.26

fig, ax = plt.subplots(figsize=(10, 4))
ax.bar([pos - w for pos in x], orig, width=w, label='Original')
ax.bar(x, kvals, width=w, label='Top-k=3')
ax.bar([pos + w for pos in x], pvals, width=w, label='Top-p=0.8')
ax.set_xticks(list(x))
ax.set_xticklabels(labels)
ax.set_title('Distribution comparison')
ax.set_ylabel('Probability')
ax.legend()
fig.tight_layout()
plt.show()
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
# Short single-letter labels keep the x-axis ticks compact; the
# probabilities are paired positionally with them.
labels = list('ABCDEF')
base_probs = [
    0.40, 0.22, 0.14,
    0.10, 0.08, 0.06,
]
def renormalize(items):
    """Rescale (label, probability) pairs so the probabilities sum to 1.

    Args:
        items: Iterable of ``(label, probability)`` pairs. One-shot
            iterables such as ``zip(...)`` are accepted.

    Returns:
        A list of ``(label, probability / total)`` pairs in input order,
        or an empty list when ``items`` is empty.

    Raises:
        ZeroDivisionError: If ``items`` is non-empty and its probabilities
            sum to zero.
    """
    # Materialize first: the input is traversed twice (sum, then rescale),
    # and a one-shot iterator would be exhausted by the first pass, which
    # would silently produce an empty result.
    pairs = list(items)
    total = sum(prob for _, prob in pairs)
    return [(label, prob / total) for label, prob in pairs]
def top_k(labels, probs, k):
    """Keep the ``k`` most probable tokens and renormalize them.

    Args:
        labels: Token labels, paired positionally with ``probs``.
        probs: Probability of each label.
        k: Number of highest-probability tokens to keep; must be >= 0.

    Returns:
        A dict built from the ``renormalize`` result for the kept
        ``(label, probability)`` pairs, keyed by label.

    Raises:
        ValueError: If ``k`` is negative.
    """
    if k < 0:
        # A negative slice bound would silently drop the *last* |k| tokens
        # rather than keep the top k, so reject it explicitly.
        raise ValueError('k must be non-negative, got {!r}'.format(k))
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    return dict(renormalize(ranked[:k]))
def top_p(labels, probs, p):
    """Keep the smallest high-probability prefix whose mass reaches ``p``.

    Tokens are ranked from most to least probable and accumulated until the
    running total is at least ``p``; the token that crosses the threshold is
    included. If the threshold is never reached, every token is kept.

    Args:
        labels: Token labels, paired positionally with ``probs``.
        probs: Probability of each label.
        p: Cumulative-probability threshold.

    Returns:
        A dict built from the ``renormalize`` result for the kept
        ``(label, probability)`` pairs, keyed by label.
    """
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    cumulative = 0
    cutoff = 0  # number of leading tokens to keep; stays 0 for empty input
    for cutoff, (_, prob) in enumerate(ranked, start=1):
        cumulative += prob
        if cumulative >= p:
            break
    return dict(renormalize(ranked[:cutoff]))
# Filter the toy distribution both ways; the results are looked up by label.
kdist = top_k(labels, base_probs, 3)
pdist = top_p(labels, base_probs, 0.8)

# Align all three series on the full label list: a token removed by a
# filter is missing from its dict and is drawn with height 0.0.
orig = base_probs
kvals = [kdist.get(token, 0.0) for token in labels]
pvals = [pdist.get(token, 0.0) for token in labels]

# Three bars per token: shifted left by one bar width, centered, and
# shifted right by one bar width.
x = range(len(labels))
w = 0.26

fig, ax = plt.subplots(figsize=(10, 4))
ax.bar([pos - w for pos in x], orig, width=w, label='Original')
ax.bar(x, kvals, width=w, label='Top-k=3')
ax.bar([pos + w for pos in x], pvals, width=w, label='Top-p=0.8')
ax.set_xticks(list(x))
ax.set_xticklabels(labels)
ax.set_title('Distribution comparison')
ax.set_ylabel('Probability')
ax.legend()
fig.tight_layout()
plt.show()