Skip to content

Commit 16b880c

Browse files
improve pmap mpo example
1 parent c8084a4 commit 16b880c

File tree

1 file changed

+28
-29
lines changed

1 file changed

+28
-29
lines changed

examples/slicing_auto_pmap_mpo.py

Lines changed: 28 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
tc.set_dtype("complex128")
2323

2424

25-
def get_circuit(n, d, params):
25+
def circuit2nodes(n, d, params, tc_mpo):
2626
c = tc.Circuit(n)
2727
c.h(range(n))
2828
for i in range(d):
@@ -32,14 +32,15 @@ def get_circuit(n, d, params):
3232
c.rx(j, theta=params[j, i, 1])
3333
for j in range(n):
3434
c.ry(j, theta=params[j, i, 2])
35-
return c
3635

37-
38-
def core(params, i, tree, n, d, tc_mpo):
39-
c = get_circuit(n, d, params)
4036
mps = c.get_quvector()
4137
e = mps.adjoint() @ tc_mpo @ mps
42-
_, nodes = tc.cons.get_tn_info(e.nodes)
38+
return e.nodes
39+
40+
41+
def core(params, i, tree, n, d, tc_mpo):
42+
nodes = circuit2nodes(n, d, params, tc_mpo)
43+
_, nodes = tc.cons.get_tn_info(nodes)
4344
input_arrays = [node.tensor for node in nodes]
4445
sliced_arrays = tree.slice_arrays(input_arrays, i)
4546
return K.real(tree.contract_core(sliced_arrays, backend=backend))[0, 0]
@@ -52,24 +53,18 @@ def core(params, i, tree, n, d, tc_mpo):
5253
nqubit = 12
5354
d = 6
5455

55-
Jx = jax.numpy.array([1.0] * (nqubit - 1)) # XX coupling strength
56-
Bz = jax.numpy.array([-1.0] * nqubit) # Transverse field strength
57-
58-
# Create TensorNetwork MPO
59-
tn_mpo = tn.matrixproductstates.mpo.FiniteTFI(Jx, Bz, dtype=np.complex64)
60-
tc_mpo = tc.quantum.tn2qop(tn_mpo)
61-
6256
# baseline results
6357
lattice = tc.templates.graphs.Line1D(nqubit, pbc=False)
6458
h = tc.quantum.heisenberg_hamiltonian(lattice, hzz=0, hyy=0, hxx=1.0, hz=-1.0)
6559
es0 = scipy.sparse.linalg.eigsh(K.numpy(h), k=1, which="SA")[0]
6660
print("exact ground state energy: ", es0)
6761

68-
params = K.implicit_randn(stddev=0.1, shape=[1, nqubit, d, 3], dtype=tc.rdtypestr)
69-
params = K.tile(params, [num_device, 1, 1, 1])
62+
params = K.implicit_randn(stddev=0.1, shape=[nqubit, d, 3], dtype=tc.rdtypestr)
63+
replicated_params = K.reshape(params, [1] + list(params.shape))
64+
replicated_params = K.tile(replicated_params, [num_device, 1, 1, 1])
7065

7166
optimizer = optax.adam(5e-2)
72-
base_opt_state = optimizer.init(params[0])
67+
base_opt_state = optimizer.init(params)
7368
replicated_opt_state = jax.tree.map(
7469
lambda x: (
7570
jax.numpy.broadcast_to(x, (num_device,) + x.shape)
@@ -93,28 +88,32 @@ def para_vag(params, i, tree, n, d, tc_mpo, opt_state):
9388
params = optax.apply_updates(params, updates)
9489
return params, opt_state, loss
9590

96-
c = get_circuit(nqubit, d, params[0])
97-
mps = c.get_quvector()
98-
e = mps.adjoint() @ tc_mpo @ mps
99-
tn_info, nodes = tc.cons.get_tn_info(e.nodes)
91+
Jx = jax.numpy.array([1.0] * (nqubit - 1)) # XX coupling strength
92+
Bz = jax.numpy.array([-1.0] * nqubit) # Transverse field strength
93+
# Create TensorNetwork MPO
94+
tn_mpo = tn.matrixproductstates.mpo.FiniteTFI(Jx, Bz, dtype=np.complex64)
95+
tc_mpo = tc.quantum.tn2qop(tn_mpo)
10096

97+
nodes = circuit2nodes(nqubit, d, params, tc_mpo)
98+
tn_info, _ = tc.cons.get_tn_info(nodes)
99+
100+
# Create ReusableHyperOptimizer for finding optimal contraction paths
101101
opt = ctg.ReusableHyperOptimizer(
102-
parallel=True,
102+
parallel=True, # Enable parallel path finding
103103
slicing_opts={
104-
"target_slices": num_device,
105-
# "target_size": 2**20, # Add memory target
104+
"target_slices": num_device, # Split computation across available devices
105+
# "target_size": 2**20, # Optional: Set memory limit per slice
106106
},
107-
max_repeats=256,
108-
progbar=True,
109-
minimize="combo",
107+
max_repeats=256, # Maximum number of path finding attempts
108+
progbar=True, # Show progress bar during optimization
109+
minimize="combo", # Optimize for both time and memory
110110
)
111-
112111
tree = opt.search(*tn_info)
113112

114113
inds = K.arange(num_device)
115114
for j in range(100):
116115
print(f"training loop: {j}-step")
117-
params, replicated_opt_state, loss = para_vag(
118-
params, inds, tree, nqubit, d, tc_mpo, replicated_opt_state
116+
replicated_params, replicated_opt_state, loss = para_vag(
117+
replicated_params, inds, tree, nqubit, d, tc_mpo, replicated_opt_state
119118
)
120119
print(loss[0])

0 commit comments

Comments
 (0)