Skip to content

Commit aaf221a

Browse files
add and optimize tebd examples
1 parent b6c0e69 commit aaf221a

File tree

2 files changed

+238
-35
lines changed

2 files changed

+238
-35
lines changed

examples/xyzmodel_tebd.py

Lines changed: 49 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
import time
6+
from functools import partial
67
import numpy as np
78
import scipy
89
import tensorcircuit as tc
@@ -11,6 +12,38 @@
1112
tc.set_dtype("complex128")
1213

1314

15+
@partial(K.jit, static_argnums=(8, 9))
16+
def apply_trotter_step(
17+
mps_tensors, hxx, hyy, hzz, hx, hy, hz, dt_step, nqubits, bond_dim
18+
):
19+
split_rules = {"max_singular_values": bond_dim}
20+
mps_circuit = tc.MPSCircuit(nqubits, tensors=mps_tensors, split=split_rules)
21+
# Apply odd bonds (1-2, 3-4, ...)
22+
23+
for i in range(0, nqubits, 2):
24+
mps_circuit.rxx(i, (i + 1) % nqubits, theta=hxx * dt_step)
25+
mps_circuit.ryy(i, (i + 1) % nqubits, theta=hyy * dt_step)
26+
mps_circuit.rzz(i, (i + 1) % nqubits, theta=hzz * dt_step)
27+
28+
# Apply even bonds (2-3, 4-5, ...)
29+
for i in range(1, nqubits, 2):
30+
mps_circuit.rxx(i, (i + 1) % nqubits, theta=2 * hxx * dt_step)
31+
mps_circuit.ryy(i, (i + 1) % nqubits, theta=2 * hyy * dt_step)
32+
mps_circuit.rzz(i, (i + 1) % nqubits, theta=2 * hzz * dt_step)
33+
34+
for i in range(0, nqubits, 2):
35+
mps_circuit.rxx(i, (i + 1) % nqubits, theta=hxx * dt_step)
36+
mps_circuit.ryy(i, (i + 1) % nqubits, theta=hyy * dt_step)
37+
mps_circuit.rzz(i, (i + 1) % nqubits, theta=hzz * dt_step)
38+
39+
for i in range(nqubits):
40+
mps_circuit.rx(i, theta=2 * hx * dt_step)
41+
mps_circuit.ry(i, theta=2 * hy * dt_step)
42+
mps_circuit.rz(i, theta=2 * hz * dt_step)
43+
44+
return mps_circuit._mps.tensors
45+
46+
1447
def heisenberg_time_evolution_mps(
1548
nqubits: int,
1649
total_time: float,
@@ -35,39 +68,20 @@ def heisenberg_time_evolution_mps(
3568
nsteps = int(total_time / dt)
3669
dt_step = dt
3770

38-
@K.jit
39-
def apply_trotter_step(mps_tensors):
40-
mps_circuit = tc.MPSCircuit(nqubits, tensors=mps_tensors, split=split_rules)
41-
# Apply odd bonds (1-2, 3-4, ...)
42-
43-
for i in range(0, nqubits, 2):
44-
mps_circuit.rxx(i, (i + 1) % nqubits, theta=hxx * dt_step)
45-
mps_circuit.ryy(i, (i + 1) % nqubits, theta=hyy * dt_step)
46-
mps_circuit.rzz(i, (i + 1) % nqubits, theta=hzz * dt_step)
47-
48-
# Apply even bonds (2-3, 4-5, ...)
49-
for i in range(1, nqubits, 2):
50-
mps_circuit.rxx(i, (i + 1) % nqubits, theta=2 * hxx * dt_step)
51-
mps_circuit.ryy(i, (i + 1) % nqubits, theta=2 * hyy * dt_step)
52-
mps_circuit.rzz(i, (i + 1) % nqubits, theta=2 * hzz * dt_step)
53-
54-
# mps_circuit.unitary(i, unitary=unitary)
55-
56-
for i in range(0, nqubits, 2):
57-
mps_circuit.rxx(i, (i + 1) % nqubits, theta=hxx * dt_step)
58-
mps_circuit.ryy(i, (i + 1) % nqubits, theta=hyy * dt_step)
59-
mps_circuit.rzz(i, (i + 1) % nqubits, theta=hzz * dt_step)
60-
61-
for i in range(nqubits):
62-
mps_circuit.rx(i, theta=2 * hx * dt_step)
63-
mps_circuit.ry(i, theta=2 * hy * dt_step)
64-
mps_circuit.rz(i, theta=2 * hz * dt_step)
65-
66-
return mps_circuit._mps.tensors
67-
6871
# Perform time evolution
69-
for step in range(nsteps):
70-
tensors = apply_trotter_step(tensors)
72+
for _ in range(nsteps):
73+
tensors = apply_trotter_step(
74+
tensors,
75+
hxx,
76+
hyy,
77+
hzz,
78+
hx,
79+
hy,
80+
hz,
81+
dt_step,
82+
nqubits,
83+
split_rules["max_singular_values"],
84+
)
7185

7286
return tc.MPSCircuit(nqubits, tensors=tensors, split=split_rules)
7387

@@ -157,7 +171,7 @@ def benchmark_efficiency(nqubits, bond_d):
157171
split_rules=split_rules,
158172
)
159173
print(final_mps._mps.tensors[0])
160-
print("cold start run:", time.time() - time0)
174+
print("cold start run:", (time.time() - time0) / 10)
161175
time0 = time.time()
162176
final_mps = heisenberg_time_evolution_mps(
163177
nqubits=nqubits,
@@ -169,9 +183,9 @@ def benchmark_efficiency(nqubits, bond_d):
169183
split_rules=split_rules,
170184
)
171185
print(final_mps._mps.tensors[0])
172-
print("jitted run:", time.time() - time0)
186+
print("jitted run:", (time.time() - time0) / 10)
173187

174188

175189
if __name__ == "__main__":
176190
compare_baseline()
177-
benchmark_efficiency(24, 48)
191+
benchmark_efficiency(32, 48)

examples/xyzmodel_tebd_obc_imag.py

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
"""
2+
1D TEBD on imaginary time toward ground state with OBC
3+
"""
4+
5+
import time
6+
from functools import partial
7+
import numpy as np
8+
import scipy
9+
import tensorcircuit as tc
10+
11+
12+
K = tc.set_backend("jax")
13+
tc.set_dtype("complex128")
14+
15+
16+
@partial(K.jit, static_argnums=(8, 9))
17+
def apply_trotter_step(
18+
mps_tensors, hxx, hyy, hzz, hx, hy, hz, dt_step, nqubits, bond_dim
19+
):
20+
split_rules = {"max_singular_values": bond_dim}
21+
mps_circuit = tc.MPSCircuit(nqubits, tensors=mps_tensors, split=split_rules)
22+
# Apply odd bonds (1-2, 3-4, ...)
23+
24+
for i in range(0, nqubits - 1, 2):
25+
mps_circuit.rxx(i, (i + 1), theta=-2.0j * hxx * dt_step)
26+
mps_circuit.ryy(i, (i + 1), theta=-2.0j * hyy * dt_step)
27+
mps_circuit.rzz(i, (i + 1), theta=-2.0j * hzz * dt_step)
28+
mps_circuit._mps.position(nqubits - 1, normalize=True)
29+
30+
# Apply even bonds (2-3, 4-5, ...)
31+
for i in reversed(range(1, nqubits - 1, 2)):
32+
mps_circuit.rxx(i, (i + 1), theta=-2.0j * hxx * dt_step)
33+
mps_circuit.ryy(i, (i + 1), theta=-2.0j * hyy * dt_step)
34+
mps_circuit.rzz(i, (i + 1), theta=-2.0j * hzz * dt_step)
35+
mps_circuit._mps.position(0, normalize=True)
36+
37+
for i in range(nqubits):
38+
mps_circuit.rx(i, theta=-2.0j * hx * dt_step)
39+
mps_circuit.ry(i, theta=-2.0j * hy * dt_step)
40+
mps_circuit.rz(i, theta=-2.0j * hz * dt_step)
41+
# mps_circuit._mps.position(0, normalize=True)
42+
43+
return mps_circuit._mps.tensors
44+
45+
46+
def heisenberg_imag_time_evolution_mps(
47+
nqubits: int,
48+
total_time: float,
49+
dt: float,
50+
hxx: float = 1.0,
51+
hyy: float = 1.0,
52+
hzz: float = 1.0,
53+
hz: float = 0.0,
54+
hx: float = 0.0,
55+
hy: float = 0.0,
56+
initial_state=None,
57+
split_rules=None,
58+
):
59+
60+
# Initialize MPS circuit
61+
if initial_state is not None:
62+
mps = tc.MPSCircuit(nqubits, wavefunction=initial_state, split=split_rules)
63+
else:
64+
mps = tc.MPSCircuit(nqubits, split=split_rules)
65+
tensors = mps._mps.tensors
66+
67+
# Number of Trotter steps
68+
nsteps = int(total_time / dt)
69+
dt_step = dt
70+
71+
# Perform time evolution
72+
for _ in range(nsteps):
73+
tensors = apply_trotter_step(
74+
tensors,
75+
hxx,
76+
hyy,
77+
hzz,
78+
hx,
79+
hy,
80+
hz,
81+
dt_step,
82+
nqubits,
83+
split_rules["max_singular_values"],
84+
)
85+
86+
return tc.MPSCircuit(nqubits, tensors=tensors, split=split_rules)
87+
88+
89+
def compare_baseline(nqubits=12):
90+
# Parameters
91+
total_time = 6
92+
dt = 0.02
93+
94+
# Heisenberg parameters
95+
hxx = 0.8
96+
hyy = 1.0
97+
hzz = 2.0
98+
hz = 0.01
99+
hy = 0.0
100+
hx = 0.0
101+
102+
split_rules = {"max_singular_values": 24}
103+
104+
c = tc.Circuit(nqubits)
105+
c.x([2 * i for i in range(nqubits // 2)])
106+
initial_state = c.state()
107+
108+
# TEBD evolution
109+
final_mps = heisenberg_imag_time_evolution_mps(
110+
nqubits=nqubits,
111+
total_time=total_time,
112+
dt=dt,
113+
hxx=hxx,
114+
hyy=hyy,
115+
hzz=hzz,
116+
hz=hz,
117+
hx=hx,
118+
hy=hy,
119+
initial_state=initial_state,
120+
split_rules=split_rules,
121+
)
122+
# Exact evolution
123+
g = tc.templates.graphs.Line1D(nqubits, pbc=False)
124+
H = tc.quantum.heisenberg_hamiltonian(
125+
g, hxx=hxx, hyy=hyy, hzz=hzz, hz=hz, hy=hy, hx=hx, sparse=False
126+
)
127+
U = scipy.linalg.expm(-total_time * H)
128+
exact_final = K.reshape(U @ K.reshape(initial_state, [-1, 1]), [-1])
129+
exact_final /= K.norm(exact_final)
130+
# Compare results
131+
mps_state = final_mps.wavefunction()
132+
fidelity = np.abs(np.vdot(exact_final, mps_state)) ** 2
133+
print(f"Fidelity between TEBD and exact evolution: {fidelity}")
134+
c_exact = tc.Circuit(nqubits, inputs=exact_final)
135+
# Measure observables
136+
z_magnetization_mps = []
137+
z_magnetization_exact = []
138+
for i in range(nqubits):
139+
mag_mps = final_mps.expectation((tc.gates.z(), [i]))
140+
z_magnetization_mps.append(mag_mps)
141+
142+
mag_exact = c_exact.expectation((tc.gates.z(), [i]))
143+
z_magnetization_exact.append(mag_exact)
144+
145+
print("MPS Z magnetization:", K.stack(z_magnetization_mps))
146+
print("Exact Z magnetization:", K.stack(z_magnetization_exact))
147+
print("Final bond dimensions:", final_mps.get_bond_dimensions())
148+
149+
return final_mps, exact_final
150+
151+
152+
def benchmark_efficiency(nqubits, bond_d):
153+
total_time = 0.2
154+
dt = 0.01
155+
hxx = 0.9
156+
hyy = 1.0
157+
hzz = 0.3
158+
split_rules = {"max_singular_values": bond_d}
159+
160+
# TEBD evolution
161+
time0 = time.time()
162+
final_mps = heisenberg_imag_time_evolution_mps(
163+
nqubits=nqubits,
164+
total_time=total_time,
165+
dt=dt,
166+
hxx=hxx,
167+
hyy=hyy,
168+
hzz=hzz,
169+
split_rules=split_rules,
170+
)
171+
print(final_mps._mps.tensors[0])
172+
print("1 step cold start run:", (time.time() - time0) / 20)
173+
time0 = time.time()
174+
final_mps = heisenberg_imag_time_evolution_mps(
175+
nqubits=nqubits,
176+
total_time=total_time,
177+
dt=dt,
178+
hxx=hxx,
179+
hyy=hyy,
180+
hzz=hzz,
181+
split_rules=split_rules,
182+
)
183+
print(final_mps._mps.tensors[0])
184+
print("1 step jitted run:", (time.time() - time0) / 20)
185+
186+
187+
if __name__ == "__main__":
188+
compare_baseline()
189+
benchmark_efficiency(32, 32)

0 commit comments

Comments
 (0)