Skip to content

Commit ec8c10a

Browse files
add entanglement for stabilizer circuit
1 parent cbb0907 commit ec8c10a

File tree

5 files changed

+166
-27
lines changed

5 files changed

+166
-27
lines changed

examples/apicomparison/2_tc_qml.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def yp(img, params):
1616
return K.real(c.expectation_ps(z=[nwires - 1]))
1717

1818

19-
model = tc.keras.QuantumLayer(yp, [(nlayer * 2 * nwires)])
19+
model = tc.keras.QuantumLayer(yp, [nlayer * 2 * nwires])
2020

2121
imgs = K.implicit_randn(shape=[nbatch, nwires])
2222

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
"""
2+
Stabilizer Circuit Benchmark
3+
"""
4+
5+
from time import time
6+
import numpy as np
7+
import tensorcircuit as tc
8+
9+
tc.set_dtype("complex128")
10+
11+
clifford_one_qubit_gates = ["H", "X", "Y", "Z", "S", "SD"]
12+
clifford_two_qubit_gates = ["CNOT", "CZ", "SWAP"]
13+
clifford_gates = clifford_one_qubit_gates + clifford_two_qubit_gates
14+
15+
16+
def genpair(num_qubits, count):
17+
choice = list(range(num_qubits))
18+
for _ in range(count):
19+
np.random.shuffle(choice)
20+
x, y = choice[:2]
21+
yield (x, y)
22+
23+
24+
def random_clifford_circuit(num_qubits, depth):
25+
c = tc.Circuit(num_qubits)
26+
for _ in range(depth):
27+
for j, k in genpair(num_qubits, num_qubits // 2):
28+
gate_name = np.random.choice(clifford_two_qubit_gates)
29+
getattr(c, gate_name)(j, k)
30+
for j in range(num_qubits):
31+
gate_name = np.random.choice(clifford_one_qubit_gates)
32+
getattr(c, gate_name)(j)
33+
return c
34+
35+
36+
if __name__ == "__main__":
37+
time_cir_gen = 0
38+
time_cir_state = 0
39+
time_cir_ee = 0
40+
time_scir_gen = 0
41+
time_scir_ee = 0
42+
for _ in range(30):
43+
t0 = time()
44+
c = random_clifford_circuit(10, 15)
45+
time_cir_gen += time() - t0
46+
t0 = time()
47+
s = c.state()
48+
time_cir_state += time() - t0
49+
t0 = time()
50+
ee0 = tc.quantum.entanglement_entropy(s, list(range(5)))
51+
time_cir_ee += time() - t0
52+
t0 = time()
53+
c1 = tc.StabilizerCircuit.from_qir(c.to_qir())
54+
time_scir_gen += time() - t0
55+
t0 = time()
56+
ee1 = c1.entanglement_entropy(list(range(5)))
57+
time_scir_ee += time() - t0
58+
np.testing.assert_allclose(
59+
ee0,
60+
ee1,
61+
atol=1e-6,
62+
)
63+
64+
print("time_cir_gen", time_cir_gen / 30)
65+
print("time_cir_state", time_cir_state / 30)
66+
print("time_cir_ee", time_cir_ee / 30)
67+
print("time_scir_gen", time_scir_gen / 30)
68+
print("time_scir_ee", time_scir_ee / 30)

examples/stabilizer_simulation.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
"""
2+
Deprecated, please see stahilizer_entanglement_benchmark.py
3+
"""
4+
15
import numpy as np
26
import stim
37

tensorcircuit/densitymatrix.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def apply_general_kraus(
227227

228228
@staticmethod
229229
def apply_general_kraus_delayed(
230-
krausf: Callable[..., Sequence[Gate]]
230+
krausf: Callable[..., Sequence[Gate]],
231231
) -> Callable[..., None]:
232232
def apply(self: "DMCircuit", *index: int, **vars: float) -> None:
233233
for key in ["status", "name"]:
@@ -391,7 +391,7 @@ def apply_general_kraus(
391391

392392
@staticmethod
393393
def apply_general_kraus_delayed(
394-
krausf: Callable[..., Sequence[Gate]]
394+
krausf: Callable[..., Sequence[Gate]],
395395
) -> Callable[..., None]:
396396
def apply(self: "DMCircuit2", *index: int, **vars: float) -> None:
397397
for key in ["status", "name"]:

tensorcircuit/stabilizercircuit.py

Lines changed: 91 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,17 @@ class StabilizerCircuit(AbstractCircuit):
1919

2020
# Add gate sets as class attributes
2121
clifford_gates = ["h", "x", "y", "z", "cnot", "cz", "swap", "s", "sd"]
22+
gate_map = {
23+
"h": "H",
24+
"x": "X",
25+
"y": "Y",
26+
"z": "Z",
27+
"cnot": "CNOT",
28+
"cz": "CZ",
29+
"swap": "SWAP",
30+
"s": "S",
31+
"sd": "S_DAG",
32+
}
2233

2334
def __init__(self, nqubits: int) -> None:
2435
self._nqubits = nqubits
@@ -55,36 +66,24 @@ def apply_general_gate(
5566
"gate": gate,
5667
"index": index,
5768
"name": name,
58-
"split": kws.get("split", False),
59-
"mpo": kws.get("mpo", False),
69+
"split": False,
70+
"mpo": False,
6071
}
61-
ir_dict = kws.get("ir_dict", None)
72+
ir_dict = kws["ir_dict"]
6273
if ir_dict is not None:
6374
ir_dict.update(gate_dict)
6475
else:
6576
ir_dict = gate_dict
6677
self._qir.append(ir_dict)
6778

68-
# Convert negative indices
69-
index = tuple([i if i >= 0 else self._nqubits + i for i in index])
70-
7179
# Map TensorCircuit gates to Stim gates
72-
gate_map = {
73-
"h": "H",
74-
"x": "X",
75-
"y": "Y",
76-
"z": "Z",
77-
"cnot": "CNOT",
78-
"cz": "CZ",
79-
"swap": "SWAP",
80-
"s": "S",
81-
"sd": "S_DAG",
82-
}
83-
if name.lower() in gate_map:
84-
self._stim_circuit.append(gate_map[name.lower()], list(index))
85-
instruction = stim.Circuit()
86-
instruction.append(gate_map[name.lower()], list(index))
87-
self.current_sim.do(instruction)
80+
81+
if name.lower() in self.gate_map:
82+
# self._stim_circuit.append(gate_map[name.lower()], list(index))
83+
instruction = f"{self.gate_map[name.lower()]} {' '.join(map(str, index))}"
84+
self._stim_circuit.append_from_stim_program_text(instruction)
85+
# append is much slower
86+
self.current_sim.do(stim.Circuit(instruction))
8887
else:
8988
raise ValueError(f"Gate {name} is not supported in stabilizer simulation")
9089

@@ -125,7 +124,8 @@ def cond_measurement(self, index: int) -> Tensor:
125124
# Convert negative indices
126125

127126
# Add measurement instructions
128-
self._stim_circuit.append("M", index)
127+
self._stim_circuit.append_from_stim_program_text("M " + str(index))
128+
# self.current_sim = None
129129
m = self.current_simulator().measure(index)
130130
# Sample once from the circuit using sampler
131131

@@ -296,7 +296,74 @@ def current_tableau(self) -> stim.Tableau:
296296
"""
297297
Return the current tableau of the circuit.
298298
"""
299-
self.current_simulator().current_inverse_tableau() ** -1
299+
return self.current_simulator().current_inverse_tableau() ** -1
300+
301+
def entanglement_entropy(self, cut: Sequence[int]) -> float:
302+
"""
303+
Calculate the entanglement entropy for a subset of qubits using stabilizer formalism.
304+
305+
:param cut: Indices of qubits to calculate entanglement entropy for
306+
:type cut: Sequence[int]
307+
:return: Entanglement entropy
308+
:rtype: float
309+
"""
310+
# Get stabilizer tableau
311+
tableau = self.current_tableau()
312+
N = len(tableau)
313+
314+
# Pre-allocate binary matrix with proper dtype
315+
# binary_matrix = np.zeros((N, 2 * N), dtype=np.int8)
316+
317+
# Vectorized conversion of stabilizers to binary matrix
318+
# z_outputs = np.array([tableau.z_output(k) for k in range(N)])
319+
# x_part = z_outputs == 1 # X
320+
# z_part = z_outputs == 3 # Z
321+
# y_part = z_outputs == 2 # Y
322+
323+
# binary_matrix[:, :N] = x_part | y_part
324+
# binary_matrix[:, N:] = z_part | y_part
325+
326+
_, _, z2x, z2z, _, _ = tableau.to_numpy()
327+
binary_matrix = np.concatenate([z2x, z2z], axis=1)
328+
# Get reduced matrix for the cut using boolean indexing
329+
cut_set = set(cut)
330+
cut_indices = np.array(
331+
[i for i in range(N) if i in cut_set]
332+
+ [i + N for i in range(N) if i in cut_set]
333+
)
334+
reduced_matrix = binary_matrix[:, cut_indices]
335+
336+
# Efficient rank calculation using Gaussian elimination
337+
matrix = reduced_matrix.copy()
338+
n_rows, n_cols = matrix.shape
339+
rank = 0
340+
row = 0
341+
342+
for col in range(n_cols):
343+
# Vectorized pivot finding
344+
pivot_rows = np.nonzero(matrix[row:, col])[0]
345+
if len(pivot_rows) > 0:
346+
pivot_row = pivot_rows[0] + row
347+
348+
# Swap rows if necessary
349+
if pivot_row != row:
350+
matrix[row], matrix[pivot_row] = (
351+
matrix[pivot_row].copy(),
352+
matrix[row].copy(),
353+
)
354+
355+
# Vectorized elimination
356+
eliminate_mask = matrix[row + 1 :, col] == 1
357+
matrix[row + 1 :][eliminate_mask] ^= matrix[row]
358+
359+
rank += 1
360+
row += 1
361+
362+
if row == n_rows:
363+
break
364+
365+
# Calculate entropy
366+
return float((rank - len(cut)) * np.log(2))
300367

301368

302369
# Call _meta_apply at module level to register the gates

0 commit comments

Comments
 (0)