Skip to content

Commit cbb0907

Browse files
add stabilizercircuit
1 parent 262e661 commit cbb0907

File tree

5 files changed

+448
-0
lines changed

5 files changed

+448
-0
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
### Added
66

7+
- Add initial draft on stabilizer simulation using stim backend
8+
79
- Add `jax_interface`
810

911
- Add `merge_count` in `results` module

requirements/requirements-extra.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ torch==2.2.2
88
# jupyter
99
mthree==1.1.0
1010
openfermion
11+
stim
1112
pylatexenc

tensorcircuit/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@
3131
from .densitymatrix import DMCircuit2
3232

3333
DMCircuit = DMCircuit2 # compatibility issue to still expose DMCircuit2
34+
35+
try:
36+
from .stabilizercircuit import StabilizerCircuit
37+
except ModuleNotFoundError:
38+
pass
39+
3440
from .gates import num_to_tensor, array_to_tensor
3541
from .vis import qir2tex, render_pdf
3642
from . import interfaces

tensorcircuit/stabilizercircuit.py

Lines changed: 303 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,303 @@
1+
"""
2+
Stabilizer circuit simulator using Stim backend
3+
"""
4+
5+
from typing import Any, Dict, List, Optional, Sequence
6+
import numpy as np
7+
import stim
8+
9+
from .abstractcircuit import AbstractCircuit
10+
11+
Tensor = Any
12+
13+
14+
class StabilizerCircuit(AbstractCircuit):
15+
"""
16+
Quantum circuit simulator for stabilizer circuits using Stim backend.
17+
Supports Clifford operations and measurements.
18+
"""
19+
20+
# Add gate sets as class attributes
21+
clifford_gates = ["h", "x", "y", "z", "cnot", "cz", "swap", "s", "sd"]
22+
23+
def __init__(self, nqubits: int) -> None:
24+
self._nqubits = nqubits
25+
self._stim_circuit = stim.Circuit()
26+
self._qir: List[Dict[str, Any]] = []
27+
self.is_dm = False
28+
self.inputs = None
29+
self._extra_qir: List[Dict[str, Any]] = []
30+
self.current_sim = stim.TableauSimulator()
31+
32+
def apply_general_gate(
33+
self,
34+
gate: Any,
35+
*index: int,
36+
name: Optional[str] = None,
37+
**kws: Any,
38+
) -> None:
39+
"""
40+
Apply a Clifford gate to the circuit.
41+
42+
:param gate: Gate to apply (must be Clifford)
43+
:type gate: Any
44+
:param index: Qubit indices to apply the gate to
45+
:type index: int
46+
:param name: Name of the gate operation, defaults to None
47+
:type name: Optional[str], optional
48+
:raises ValueError: If non-Clifford gate is applied
49+
"""
50+
if name is None:
51+
name = ""
52+
53+
# Record gate in QIR
54+
gate_dict = {
55+
"gate": gate,
56+
"index": index,
57+
"name": name,
58+
"split": kws.get("split", False),
59+
"mpo": kws.get("mpo", False),
60+
}
61+
ir_dict = kws.get("ir_dict", None)
62+
if ir_dict is not None:
63+
ir_dict.update(gate_dict)
64+
else:
65+
ir_dict = gate_dict
66+
self._qir.append(ir_dict)
67+
68+
# Convert negative indices
69+
index = tuple([i if i >= 0 else self._nqubits + i for i in index])
70+
71+
# Map TensorCircuit gates to Stim gates
72+
gate_map = {
73+
"h": "H",
74+
"x": "X",
75+
"y": "Y",
76+
"z": "Z",
77+
"cnot": "CNOT",
78+
"cz": "CZ",
79+
"swap": "SWAP",
80+
"s": "S",
81+
"sd": "S_DAG",
82+
}
83+
if name.lower() in gate_map:
84+
self._stim_circuit.append(gate_map[name.lower()], list(index))
85+
instruction = stim.Circuit()
86+
instruction.append(gate_map[name.lower()], list(index))
87+
self.current_sim.do(instruction)
88+
else:
89+
raise ValueError(f"Gate {name} is not supported in stabilizer simulation")
90+
91+
apply = apply_general_gate
92+
93+
def measure(self, *index: int, with_prob: bool = False) -> Tensor:
94+
"""
95+
Measure qubits in Z basis.
96+
97+
:param index: Indices of qubits to measure
98+
:type index: int
99+
:param with_prob: Return probability of measurement outcome, defaults to False
100+
:type with_prob: bool, optional
101+
:return: Measurement results and probability (if with_prob=True)
102+
:rtype: Union[np.ndarray, Tuple[np.ndarray, float]]
103+
"""
104+
# Convert negative indices
105+
106+
index = tuple([i for i in index if i >= 0])
107+
108+
# Add measurement instructions
109+
s1 = self.current_simulator().copy()
110+
m = s1.measure_many(*index)
111+
# Sample once from the circuit using sampler
112+
113+
# TODO(@refraction-ray): correct probability
114+
return m
115+
116+
def cond_measurement(self, index: int) -> Tensor:
117+
"""
118+
Measure qubits in Z basis with state collapse.
119+
120+
:param index: Index of qubit to measure
121+
:type index: int
122+
:return: Measurement results and probability (if with_prob=True)
123+
:rtype: Union[np.ndarray, Tuple[np.ndarray, float]]
124+
"""
125+
# Convert negative indices
126+
127+
# Add measurement instructions
128+
self._stim_circuit.append("M", index)
129+
m = self.current_simulator().measure(index)
130+
# Sample once from the circuit using sampler
131+
132+
return m
133+
134+
cond_measure = cond_measurement
135+
136+
def sample(
137+
self,
138+
batch: Optional[int] = None,
139+
**kws: Any,
140+
) -> Tensor:
141+
"""
142+
Sample measurements from the circuit.
143+
144+
:param batch: Number of samples to take, defaults to None (single sample)
145+
:type batch: Optional[int], optional
146+
:return: Measurement results
147+
:rtype: Tensor
148+
"""
149+
if batch is None:
150+
batch = 1
151+
c = self.current_circuit().copy()
152+
for i in range(self._nqubits):
153+
c.append("M", [i])
154+
sampler = c.compile_sampler()
155+
samples = sampler.sample(batch)
156+
return np.array(samples)
157+
158+
def expectation_ps( # type: ignore
159+
self,
160+
x: Optional[Sequence[int]] = None,
161+
y: Optional[Sequence[int]] = None,
162+
z: Optional[Sequence[int]] = None,
163+
**kws: Any,
164+
) -> Any:
165+
"""
166+
Compute exact expectation value of Pauli string using stim's direct calculation.
167+
168+
:param x: Indices for Pauli X measurements
169+
:type x: Optional[Sequence[int]], optional
170+
:param y: Indices for Pauli Y measurements
171+
:type y: Optional[Sequence[int]], optional
172+
:param z: Indices for Pauli Z measurements
173+
:type z: Optional[Sequence[int]], optional
174+
:return: Expectation value
175+
:rtype: float
176+
"""
177+
# Build Pauli string representation
178+
pauli_str = ["I"] * self._nqubits
179+
180+
if x:
181+
for i in x:
182+
pauli_str[i] = "X"
183+
if y:
184+
for i in y:
185+
pauli_str[i] = "Y"
186+
if z:
187+
for i in z:
188+
pauli_str[i] = "Z"
189+
190+
pauli_string = "".join(pauli_str)
191+
# Calculate expectation using stim's direct method
192+
expectation = self.current_simulator().peek_observable_expectation(
193+
stim.PauliString(pauli_string)
194+
)
195+
return expectation
196+
197+
expps = expectation_ps
198+
199+
def sample_expectation_ps(
200+
self,
201+
x: Optional[Sequence[int]] = None,
202+
y: Optional[Sequence[int]] = None,
203+
z: Optional[Sequence[int]] = None,
204+
shots: Optional[int] = None,
205+
**kws: Any,
206+
) -> float:
207+
"""
208+
Compute expectation value of Pauli string measurements.
209+
210+
:param x: Indices for Pauli X measurements, defaults to None
211+
:type x: Optional[Sequence[int]], optional
212+
:param y: Indices for Pauli Y measurements, defaults to None
213+
:type y: Optional[Sequence[int]], optional
214+
:param z: Indices for Pauli Z measurements, defaults to None
215+
:type z: Optional[Sequence[int]], optional
216+
:param shots: Number of measurement shots, defaults to None
217+
:type shots: Optional[int], optional
218+
:return: Expectation value
219+
:rtype: float
220+
"""
221+
if shots is None:
222+
shots = 1000 # Default number of shots
223+
224+
circuit = self._stim_circuit.copy()
225+
226+
# Add basis rotations for measurements
227+
if x:
228+
for i in x:
229+
circuit.append("H", [i])
230+
if y:
231+
for i in y:
232+
circuit.append("S_DAG", [i])
233+
circuit.append("H", [i])
234+
235+
# Add measurements
236+
measured_qubits: List[int] = []
237+
if x:
238+
measured_qubits.extend(x)
239+
if y:
240+
measured_qubits.extend(y)
241+
if z:
242+
measured_qubits.extend(z)
243+
244+
for i in measured_qubits:
245+
circuit.append("M", [i])
246+
247+
# Sample and compute expectation using sampler
248+
sampler = circuit.compile_sampler()
249+
samples = sampler.sample(shots)
250+
results = np.array(samples)
251+
252+
# Convert from {0,1} to {1,-1}
253+
results = 1 - 2 * results
254+
255+
# Average over shots
256+
expectation = np.mean(np.prod(results, axis=1))
257+
258+
return float(expectation)
259+
260+
sexpps = sample_expectation_ps
261+
262+
def mid_measurement(self, index: int, keep: int = 0) -> Tensor:
263+
"""
264+
Perform a mid-measurement operation on a qubit on z direction.
265+
The post-selection cannot be recorded in ``stim.Circuit``
266+
267+
:param index: Index of the qubit to measure
268+
:type index: int
269+
:param keep: State of qubits to keep after measurement, defaults to 0 (up)
270+
:type keep: int, optional
271+
:return: Result of the mid-measurement operation
272+
:rtype: Tensor
273+
"""
274+
if keep not in [0, 1]:
275+
raise ValueError("keep must be 0 or 1")
276+
277+
self.current_sim.postselect_z(index, desired_value=keep)
278+
279+
mid_measure = mid_measurement
280+
post_select = mid_measurement
281+
post_selection = mid_measurement
282+
283+
def current_simulator(self) -> stim.TableauSimulator:
284+
"""
285+
Return the current simulator of the circuit.
286+
"""
287+
return self.current_sim
288+
289+
def current_circuit(self) -> stim.Circuit:
290+
"""
291+
Return the current stim circuit representation of the circuit.
292+
"""
293+
return self._stim_circuit
294+
295+
def current_tableau(self) -> stim.Tableau:
296+
"""
297+
Return the current tableau of the circuit.
298+
"""
299+
self.current_simulator().current_inverse_tableau() ** -1
300+
301+
302+
# Call _meta_apply at module level to register the gates
303+
StabilizerCircuit._meta_apply()

0 commit comments

Comments
 (0)