import numpy as np
from collections import Counter
from sklearn.datasets import fetch_openml
from skimage.transform import resize
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")

# --- STEP 1: Load and preprocess MNIST zeros (2x2 binarized) ---

print("Downloading and preprocessing MNIST...")
mnist = fetch_openml("mnist_784", version=1, as_frame=False)
X, y = mnist["data"], mnist["target"]
X_zeros = X[y == '0'] / 255.0
X_zeros = X_zeros[:200]

def downsample_binarize(img, size=2):
    """Nearest-neighbour downsample a 28x28 digit to size x size, threshold at 0.5,
    and flatten to a bitstring such as '0110'. The default size=2 yields 4-bit patterns,
    matching the 4-qubit model defined below (4x4 images would require 16 qubits,
    which the dense-matrix simulator used here cannot handle)."""
    img = img.reshape(28, 28)
    small = resize(img, (size, size), order=0, anti_aliasing=False, preserve_range=True)
    binary = (small > 0.5).astype(int)
    return ''.join(map(str, binary.flatten()))

samples_bin = [downsample_binarize(img) for img in X_zeros]
data_dist = Counter(samples_bin)
total = sum(data_dist.values())
data_dist = {k: v / total for k, v in data_dist.items()}
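
# Optional, illustrative check (not in the original pipeline): how concentrated is the
# empirical pattern distribution? Uses only objects defined above.
top3 = sorted(data_dist.items(), key=lambda kv: kv[1], reverse=True)[:3]
print(f"{len(data_dist)} distinct patterns among {total} digits; top 3:")
for pattern, p in top3:
    print(f"  {pattern}: {p:.3f}")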

# --- STEP 2: Quantum Circuit Utils ---

def Ry(theta):
    """Single-qubit rotation about the Y axis."""
    return np.array([
        [np.cos(theta/2), -np.sin(theta/2)],
        [np.sin(theta/2),  np.cos(theta/2)]
    ])

def CNOT(n, control, target):
    """Full 2^n x 2^n CNOT on qubits (control, target); qubit 0 is the most
    significant bit of the basis-state label."""
    dim = 2**n
    op = np.zeros((dim, dim), dtype=complex)
    for i in range(dim):
        bits = list(np.binary_repr(i, width=n))
        if bits[control] == '1':
            bits[target] = '1' if bits[target] == '0' else '0'
        j = int(''.join(bits), 2)
        op[i, j] = 1
    return op
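
# Optional sanity check (illustrative, not in the original script): CNOT should be a
# real permutation matrix, hence unitary, and should map |10> to |11>.
_c = CNOT(2, 0, 1)
assert np.allclose(_c @ _c.conj().T, np.eye(4))
assert np.allclose(_c @ np.array([0, 0, 1, 0]), np.array([0, 0, 0, 1]))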

def variational_state(params):
    """One layer of Ry rotations (one angle per qubit) followed by a ladder of
    nearest-neighbour CNOTs, applied to |0...0>."""
    n = len(params)
    state = np.zeros(2**n, dtype=complex)
    state[0] = 1
    U = 1
    for theta in params:
        U = np.kron(U, Ry(theta))
    state = U @ state
    for i in range(n - 1):
        state = CNOT(n, i, i + 1) @ state
    return state
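
# Optional check (illustrative, not in the original script): the circuit is unitary,
# so the output state should stay normalized.
_psi_test = variational_state(np.random.uniform(0, 2*np.pi, 4))
assert np.isclose(np.sum(np.abs(_psi_test)**2), 1.0)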

def sample_state(psi, num_samples=1000):
    probs = np.abs(psi)**2
    probs = probs / probs.sum()  # guard against floating-point drift before sampling
    states = [format(i, f'0{int(np.log2(len(psi)))}b') for i in range(len(psi))]
    return np.random.choice(states, size=num_samples, p=probs)

def get_prob_dist(samples):
    counts = Counter(samples)
    total = sum(counts.values())
    return {x: c / total for x, c in counts.items()}
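
# Illustrative helper (not in the original script): total variation distance between
# two distributions given as {bitstring: probability} dicts, handy for quantifying
# how closely the model matches the data.
def total_variation_distance(p, q):
    keys = set(p) | set(q)
    return 0.5 * sum(abs(p.get(k, 0.0) - q.get(k, 0.0)) for k in keys)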

# --- Contrastive Divergence Loss ---

def energy(bitstring, psi):
    """Born-rule 'energy' of a bitstring: negative log of its measurement probability."""
    index = int(bitstring, 2)
    prob = np.abs(psi[index])**2
    return -np.log(prob + 1e-10)

def contrastive_divergence_loss(psi, data_samples, model_samples):
    """CD-style objective: mean energy over data samples (positive phase) minus
    mean energy over model samples (negative phase)."""
    E_data = np.mean([energy(x, psi) for x in data_samples])
    E_model = np.mean([energy(x, psi) for x in model_samples])
    return E_data - E_model

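# Illustrative extra (not in the original script): with the full statevector available
# in simulation, the exact KL divergence from the data distribution to the Born
# distribution can be tracked alongside the sampled CD loss.
def exact_kl(data_dist, psi):
    return sum(p * (np.log(p) - np.log(np.abs(psi[int(s, 2)])**2 + 1e-10))
               for s, p in data_dist.items())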

def parameter_shift_grad_cd(params, data_samples, shift=np.pi/2, num_samples=500):
    """Parameter-shift-style gradient estimate of the CD loss. Note: for this
    log-probability-based loss the +/- pi/2 shifts act as a symmetric finite-difference
    heuristic rather than an exact shift rule (which holds for observable expectations)."""
    grads = np.zeros_like(params)
    for i in range(len(params)):
        plus = params.copy()
        minus = params.copy()
        plus[i] += shift
        minus[i] -= shift

        psi_plus = variational_state(plus)
        psi_minus = variational_state(minus)

        model_plus = sample_state(psi_plus, num_samples)
        model_minus = sample_state(psi_minus, num_samples)

        loss_plus = contrastive_divergence_loss(psi_plus, data_samples, model_plus)
        loss_minus = contrastive_divergence_loss(psi_minus, data_samples, model_minus)

        grads[i] = 0.5 * (loss_plus - loss_minus)
    return grads

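# Illustrative alternative (not in the original script): an SPSA-style estimator
# perturbs every parameter at once and needs only two loss evaluations per step,
# which scales better than per-parameter shifts. The helper name and 'eps' value
# are assumptions for this sketch.
def spsa_grad_cd(params, data_samples, eps=0.1, num_samples=500):
    delta = np.random.choice([-1.0, 1.0], size=len(params))

    def _loss(p):
        psi = variational_state(p)
        return contrastive_divergence_loss(psi, data_samples, sample_state(psi, num_samples))

    return (_loss(params + eps * delta) - _loss(params - eps * delta)) / (2 * eps) * delta
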
# --- STEP 3: Training ---

n_qubits = 4
params = np.random.uniform(0, 2*np.pi, size=n_qubits)
lr = 0.1
data_samples = samples_bin[:100]

print("\nTraining VQBM with Contrastive Divergence...\n")
for step in range(100):
    psi = variational_state(params)
    model_samples = sample_state(psi, num_samples=500)
    loss = contrastive_divergence_loss(psi, data_samples, model_samples)

    grads = parameter_shift_grad_cd(params, data_samples)
    params -= lr * grads

    if step % 10 == 0:
        print(f"Step {step:3d}: CD Loss = {loss:.4f}")

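# Optional monitoring (illustrative, not in the original script): how much probability
# mass the trained state places on the patterns that actually occur in the data.
psi_check = variational_state(params)
data_support_mass = sum(np.abs(psi_check[int(s, 2)])**2 for s in data_dist)
print(f"Model probability mass on patterns seen in the data: {data_support_mass:.3f}")
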
# --- STEP 4: Plot Results ---

psi_final = variational_state(params)
samples = sample_state(psi_final, 2000)
model_dist = get_prob_dist(samples)

# Top k states
top_k = 10
all_states = list(set(list(data_dist.keys()) + list(model_dist.keys())))
top_states = sorted(all_states, key=lambda s: data_dist.get(s, 0) + model_dist.get(s, 0), reverse=True)[:top_k]

x = np.arange(len(top_states))
data_vals = [data_dist.get(s, 0) for s in top_states]
model_vals = [model_dist.get(s, 0) for s in top_states]

plt.figure(figsize=(10, 5))
plt.bar(x - 0.2, data_vals, width=0.4, label="Data")
plt.bar(x + 0.2, model_vals, width=0.4, label="Model")
plt.xticks(x, top_states, rotation=45)
plt.ylabel("Probability")
plt.title("Top Learned Distributions: Data vs VQBM Model")
plt.legend()
plt.tight_layout()
plt.show()
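
# Optional visual check (illustrative, not in the original script): render the model's
# most probable 2x2 pattern as a tiny image.
best = max(model_dist, key=model_dist.get)
plt.figure(figsize=(2, 2))
plt.imshow(np.array(list(best), dtype=int).reshape(2, 2), cmap="gray_r")
plt.title(f"p = {model_dist[best]:.2f}")
plt.axis("off")
plt.show()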