
Commit c2c455b

Create qrbm5.py
1 parent 450f8cf commit c2c455b

File tree

1 file changed: +143 -0 lines changed

doc/Programs/QuantumRBM/qrbm5.py

Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
import numpy as np
from collections import Counter
from sklearn.datasets import fetch_openml
from skimage.transform import resize
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")

# --- STEP 1: Load and preprocess MNIST zeros (2x2 binarized) ---
# NOTE: size*size must equal n_qubits (set to 4 in STEP 3), because energy()
# indexes the 2**n_qubits-dimensional state vector by the integer value of
# each bitstring. With size=2 every sample is a 4-bit string.

print("Downloading and preprocessing MNIST...")
mnist = fetch_openml("mnist_784", version=1, as_frame=False)
X, y = mnist["data"], mnist["target"]
X_zeros = X[y == '0'] / 255.0
X_zeros = X_zeros[:200]

def downsample_binarize(img, size=2):
    # Downsample a 28x28 image to size x size pixels (nearest neighbor),
    # threshold at 0.5, and return the pixels as a flat bitstring, e.g. '0110'.
    img = img.reshape(28, 28)
    small = resize(img, (size, size), order=0, anti_aliasing=False, preserve_range=True)
    binary = (small > 0.5).astype(int)
    return ''.join(map(str, binary.flatten()))

samples_bin = [downsample_binarize(img) for img in X_zeros]
data_dist = Counter(samples_bin)
total = sum(data_dist.values())
data_dist = {k: v / total for k, v in data_dist.items()}
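# After preprocessing, each sample is a 4-character bitstring (one bit per
# pixel of the 2x2 image) and data_dist maps each observed bitstring to its
# empirical frequency, for example (hypothetical values):
#   {'0000': 0.45, '0110': 0.30, '0100': 0.25}
# To peek at the most common patterns one could uncomment:
# print(sorted(data_dist.items(), key=lambda kv: -kv[1])[:5])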
# --- STEP 2: Quantum Circuit Utils ---

def Ry(theta):
    # Single-qubit rotation about the y-axis by angle theta.
    return np.array([
        [np.cos(theta/2), -np.sin(theta/2)],
        [np.sin(theta/2),  np.cos(theta/2)]
    ])

def CNOT(n, control, target):
    # Dense matrix of a CNOT acting on qubits `control` and `target` within an
    # n-qubit register (basis states labelled by n-bit strings).
    dim = 2**n
    op = np.zeros((dim, dim), dtype=complex)
    for i in range(dim):
        bits = list(np.binary_repr(i, width=n))
        if bits[control] == '1':
            bits[target] = '1' if bits[target] == '0' else '0'
        j = int(''.join(bits), 2)
        op[i, j] = 1
    return op
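# The variational ansatz below is a single layer of per-qubit Ry rotations
# (one angle per qubit) followed by a chain of CNOTs entangling neighboring
# qubits. The squared amplitudes of the resulting state vector play the role
# of the model's (Born-rule) probability distribution over bitstrings.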
def variational_state(params):
    n = len(params)
    state = np.zeros(2**n, dtype=complex)
    state[0] = 1
    U = 1
    for theta in params:
        U = np.kron(U, Ry(theta))
    state = U @ state
    for i in range(n - 1):
        state = CNOT(n, i, i + 1) @ state
    return state
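# Illustrative sanity check (assumes 4 qubits, as used in STEP 3): the ansatz
# is built from orthogonal Ry blocks and permutation CNOTs, so the state it
# produces should stay normalized.
_psi_check = variational_state(np.random.uniform(0, 2 * np.pi, size=4))
assert np.isclose(np.linalg.norm(_psi_check), 1.0)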
def sample_state(psi, num_samples=1000):
    probs = np.abs(psi)**2
    states = [format(i, f'0{int(np.log2(len(psi)))}b') for i in range(len(psi))]
    return np.random.choice(states, size=num_samples, p=probs)

def get_prob_dist(samples):
    counts = Counter(samples)
    total = sum(counts.values())
    return {x: c / total for x, c in counts.items()}
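# Example usage (hypothetical output): sample_state(psi, 5) might return
# array(['0000', '0110', '0000', '0000', '0110']), i.e. bitstrings drawn with
# probability |psi_x|**2; get_prob_dist turns such samples back into a
# normalized frequency dictionary.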
# --- Contrastive Divergence Loss ---
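# The "energy" of a bitstring x is defined here as the negative log of its
# Born probability under the current state, E(x) = -log(|psi_x|^2 + eps), so
# low-energy strings are the ones the circuit is likely to emit. The
# contrastive-divergence-style loss below is the mean energy of the data
# samples minus the mean energy of samples drawn from the model itself;
# pushing it down lowers the energy of data patterns relative to the model's
# own samples.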
def energy(bitstring, psi):
    index = int(bitstring, 2)
    prob = np.abs(psi[index])**2
    return -np.log(prob + 1e-10)

def contrastive_divergence_loss(psi, data_samples, model_samples):
    E_data = np.mean([energy(x, psi) for x in data_samples])
    E_model = np.mean([energy(x, psi) for x in model_samples])
    return E_data - E_model
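# Gradient estimate: each angle is shifted by +/- pi/2 and the CD loss is
# re-evaluated with fresh model samples; the gradient component is
# 0.5 * (loss_plus - loss_minus). This mirrors the parameter-shift rule for
# Ry gates (which is exact for expectation values of observables); applied to
# this sampled CD loss it should be read as a shift-based finite-difference
# estimate rather than an exact gradient.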
def parameter_shift_grad_cd(params, data_samples, shift=np.pi/2, num_samples=500):
    grads = np.zeros_like(params)
    for i in range(len(params)):
        plus = params.copy()
        minus = params.copy()
        plus[i] += shift
        minus[i] -= shift

        psi_plus = variational_state(plus)
        psi_minus = variational_state(minus)

        model_plus = sample_state(psi_plus, num_samples)
        model_minus = sample_state(psi_minus, num_samples)

        loss_plus = contrastive_divergence_loss(psi_plus, data_samples, model_plus)
        loss_minus = contrastive_divergence_loss(psi_minus, data_samples, model_minus)

        grads[i] = 0.5 * (loss_plus - loss_minus)
    return grads
# --- STEP 3: Training ---
n_qubits = 4  # must equal the bitstring length produced in STEP 1 (2*2 pixels)
params = np.random.uniform(0, 2*np.pi, size=n_qubits)
lr = 0.1
data_samples = samples_bin[:100]

print("\nTraining VQBM with Contrastive Divergence...\n")
for step in range(100):
    psi = variational_state(params)
    model_samples = sample_state(psi, num_samples=500)
    loss = contrastive_divergence_loss(psi, data_samples, model_samples)

    grads = parameter_shift_grad_cd(params, data_samples)
    params -= lr * grads

    if step % 10 == 0:
        print(f"Step {step:3d}: CD Loss = {loss:.4f}")
# --- STEP 4: Plot Results ---
psi_final = variational_state(params)
samples = sample_state(psi_final, 2000)
model_dist = get_prob_dist(samples)

# Top k states
top_k = 10
all_states = list(set(list(data_dist.keys()) + list(model_dist.keys())))
top_states = sorted(all_states, key=lambda s: data_dist.get(s, 0) + model_dist.get(s, 0), reverse=True)[:top_k]

x = np.arange(len(top_states))
data_vals = [data_dist.get(s, 0) for s in top_states]
model_vals = [model_dist.get(s, 0) for s in top_states]

plt.figure(figsize=(10, 5))
plt.bar(x - 0.2, data_vals, width=0.4, label="Data")
plt.bar(x + 0.2, model_vals, width=0.4, label="Model")
plt.xticks(x, top_states, rotation=45)
plt.ylabel("Probability")
plt.title("Top Learned Distributions: Data vs VQBM Model")
plt.legend()
plt.tight_layout()
plt.show()
