Commit 8cd5b0a

enable distribution infrastructure with major changes to the contractor and quantum operators
1 parent e678a92 commit 8cd5b0a

File tree

6 files changed: +347 -29 lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -10,6 +10,10 @@

 - Add `merge_count` in `results` module

+### Fixed
+
+- Better contractor infrastructure with a "breakpoint" contractor to directly obtain the tensor networks
+
 ## 1.1.0

 ### Added
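The "breakpoint" contractor mentioned in the entry above is exposed in this commit as `tc.cons.get_tn_info` (see the `tensorcircuit/cons.py` diff below): it runs the cache-friendly contractor preprocessing but stops before any contraction, handing back the einsum-style network description together with the reordered nodes. A minimal usage sketch, not part of the commit; the circuit, qubit count, and observable here are arbitrary illustrative choices, and the unpacking of `tn_info` follows from the `_identity` "algorithm" returning its arguments unchanged:

import tensorcircuit as tc

c = tc.Circuit(4)
c.h(range(4))
# tensor network for <Z0>, built but not yet contracted
nodes = c.expectation_before([tc.gates.z(), [0]], reuse=False)

# "breakpoint": stop before contracting and hand the network to an external path finder
tn_info, sorted_nodes = tc.cons.get_tn_info(nodes)
inputs, output, size_dict = tn_info  # per-tensor symbols, open indices, bond dimensions
arrays = [node.tensor for node in sorted_nodes]  # raw tensors in the matching order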

examples/slicing_auto_pmap_mpo.py

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
"""
This script illustrates how to parallelize both the contraction path
finding and sliced contraction computation
"""

from functools import partial
import os

num_device = 4
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={num_device}"
import cotengra as ctg
import tensornetwork as tn

import numpy as np
import scipy
import jax
import optax
import tensorcircuit as tc

backend = "jax"
K = tc.set_backend(backend)
tc.set_dtype("complex128")


def get_circuit(n, d, params):
    c = tc.Circuit(n)
    c.h(range(n))
    for i in range(d):
        for j in range(0, n - 1):
            c.rzz(j, j + 1, theta=params[j, i, 0])
        for j in range(n):
            c.rx(j, theta=params[j, i, 1])
        for j in range(n):
            c.ry(j, theta=params[j, i, 2])
    return c


def core(params, i, tree, n, d, tc_mpo):
    c = get_circuit(n, d, params)
    mps = c.get_quvector()
    e = mps.adjoint() @ tc_mpo @ mps
    _, nodes = tc.cons.get_tn_info(e.nodes)
    input_arrays = [node.tensor for node in nodes]
    sliced_arrays = tree.slice_arrays(input_arrays, i)
    return K.real(tree.contract_core(sliced_arrays, backend=backend))[0, 0]


core_vag = K.value_and_grad(core)


if __name__ == "__main__":
    nqubit = 12
    d = 6

    Jx = jax.numpy.array([1.0] * (nqubit - 1))  # XX coupling strength
    Bz = jax.numpy.array([-1.0] * nqubit)  # Transverse field strength

    # Create TensorNetwork MPO
    tn_mpo = tn.matrixproductstates.mpo.FiniteTFI(Jx, Bz, dtype=np.complex64)
    tc_mpo = tc.quantum.tn2qop(tn_mpo)

    # baseline results
    lattice = tc.templates.graphs.Line1D(nqubit, pbc=False)
    h = tc.quantum.heisenberg_hamiltonian(lattice, hzz=0, hyy=0, hxx=1.0, hz=-1.0)
    es0 = scipy.sparse.linalg.eigsh(K.numpy(h), k=1, which="SA")[0]
    print("exact ground state energy: ", es0)

    params = K.implicit_randn(stddev=0.1, shape=[1, nqubit, d, 3], dtype=tc.rdtypestr)
    params = K.tile(params, [num_device, 1, 1, 1])

    optimizer = optax.adam(5e-2)
    base_opt_state = optimizer.init(params[0])
    replicated_opt_state = jax.tree.map(
        lambda x: (
            jax.numpy.broadcast_to(x, (num_device,) + x.shape)
            if isinstance(x, jax.numpy.ndarray)
            else x
        ),
        base_opt_state,
    )

    @partial(
        jax.pmap,
        axis_name="pmap",
        in_axes=(0, 0, None, None, None, None, 0),
        static_broadcasted_argnums=(2, 3, 4, 5),
    )
    def para_vag(params, i, tree, n, d, tc_mpo, opt_state):
        loss, grads = core_vag(params, i, tree, n, d, tc_mpo)
        grads = jax.lax.psum(grads, axis_name="pmap")
        loss = jax.lax.psum(loss, axis_name="pmap")
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss

    c = get_circuit(nqubit, d, params[0])
    mps = c.get_quvector()
    e = mps.adjoint() @ tc_mpo @ mps
    tn_info, nodes = tc.cons.get_tn_info(e.nodes)

    opt = ctg.ReusableHyperOptimizer(
        parallel=True,
        slicing_opts={
            "target_slices": num_device,
            # "target_size": 2**20,  # Add memory target
        },
        max_repeats=256,
        progbar=True,
        minimize="combo",
    )

    tree = opt.search(*tn_info)

    inds = K.arange(num_device)
    for j in range(100):
        print(f"training loop: {j}-step")
        params, replicated_opt_state, loss = para_vag(
            params, inds, tree, nqubit, d, tc_mpo, replicated_opt_state
        )
        print(loss[0])
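For intuition about the slicing itself: each pmap replica above contracts exactly one slice of the same network (its slice index comes from `inds`), and the `psum` over the "pmap" axis adds the slice values back into the full contraction. A serial sketch of that identity, not part of the example; it reuses `tree` and the `input_arrays` list built inside `core`, and assumes `tree.nslices` is the cotengra contraction-tree attribute giving the total number of slices:

# illustrative serial check: summing all slices on one device should reproduce
# the unsliced contraction value obtained by the pmap + psum combination above
total = sum(
    tree.contract_core(tree.slice_arrays(input_arrays, s), backend="jax")
    for s in range(tree.nslices)
)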

examples/slicing_auto_pmap_vqa.py

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
"""
This script illustrates how to parallelize both the contraction path
finding and sliced contraction computation
"""

from functools import partial
import os

num_device = 8
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={num_device}"
import cotengra as ctg
import jax
import optax
import tensorcircuit as tc

backend = "jax"
K = tc.set_backend(backend)


def get_circuit(n, d, params):
    c = tc.Circuit(n)
    for i in range(d):
        for j in range(0, n - 1):
            c.rzz(j, j + 1, theta=params[j, i, 0])
        for j in range(0, n):
            c.rx(j, theta=params[j, i, 1])
    return c


def core(params, i, tree, n, d):
    c = get_circuit(n, d, params)
    nodes = c.expectation_before([tc.gates.z(), [0]], reuse=False)
    _, nodes = tc.cons.get_tn_info(nodes)
    input_arrays = [node.tensor for node in nodes]
    sliced_arrays = tree.slice_arrays(input_arrays, i)
    return K.real(tree.contract_core(sliced_arrays, backend=backend))


core_vag = K.value_and_grad(core)


if __name__ == "__main__":
    nqubit = 14
    d = 7

    params = K.ones([1, nqubit, d, 2], dtype=tc.rdtypestr)
    params = K.tile(params, [num_device, 1, 1, 1])

    optimizer = optax.adam(5e-2)
    base_opt_state = optimizer.init(params[0])
    replicated_opt_state = jax.tree.map(
        lambda x: (
            jax.numpy.broadcast_to(x, (num_device,) + x.shape)
            if isinstance(x, jax.numpy.ndarray)
            else x
        ),
        base_opt_state,
    )

    @partial(
        jax.pmap,
        axis_name="pmap",
        in_axes=(0, 0, None, None, None, 0),
        static_broadcasted_argnums=(2, 3, 4),
    )
    def para_vag(params, i, tree, n, d, opt_state):
        loss, grads = core_vag(params, i, tree, n, d)
        grads = jax.lax.psum(grads, axis_name="pmap")
        loss = jax.lax.psum(loss, axis_name="pmap")
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss

    c = get_circuit(nqubit, d, params[0])
    nodes = c.expectation_before([tc.gates.z(), [0]], reuse=False)
    tn_info, _ = tc.cons.get_tn_info(nodes)

    opt = ctg.ReusableHyperOptimizer(
        parallel=True,
        slicing_opts={
            "target_slices": num_device,
            # "target_size": 2**20,  # Add memory target
        },
        max_repeats=256,
        progbar=True,
        minimize="combo",
    )

    tree = opt.search(*tn_info)

    inds = K.arange(num_device)
    for j in range(20):
        print(f"training loop: {j}-step")
        params, replicated_opt_state, loss = para_vag(
            params, inds, tree, nqubit, d, replicated_opt_state
        )
        print(loss[0])
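A note on the device setup shared by both examples: the `XLA_FLAGS` override at the top forces JAX to expose several virtual CPU devices so the scripts run on a single host. On a machine with real accelerators one would presumably drop that override and size the slicing to the hardware instead, e.g.:

import jax

num_device = jax.device_count()  # e.g. the number of visible GPUs
# then pass {"target_slices": num_device} to ctg.ReusableHyperOptimizer as above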

tensorcircuit/cons.py

Lines changed: 48 additions & 9 deletions
@@ -8,7 +8,7 @@
 import sys
 import time
 from contextlib import contextmanager
-from functools import partial, reduce, wraps
+from functools import partial, reduce, wraps, lru_cache
 from operator import mul
 from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, Tuple

@@ -439,6 +439,28 @@ def tn_greedy_contractor(
 # base = tn.contractors.opt_einsum_paths.path_contractors.base
 # utils = tn.contractors.opt_einsum_paths.utils

+_einsum_symbols_base = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
+
+
+@lru_cache(2**14)
+def get_symbol(i: int) -> str:
+    """Get the symbol corresponding to int ``i`` - runs through the usual 52
+    letters before resorting to unicode characters, starting at ``chr(192)``
+    and skipping surrogates. From the cotengra codebase.
+    """
+    if i < 52:
+        # use a-z, A-Z first
+        return _einsum_symbols_base[i]
+
+    # then proceed from 'À'
+    i += 140
+
+    if i >= 55296:
+        # skip the surrogate range chr(55296) - chr(57343)
+        i += 2048
+
+    return chr(i)
+

 def _get_path(
     nodes: List[tn.Node], algorithm: Any
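A quick illustrative check of the symbol mapping above, not part of the diff; it assumes the new helper is importable from `tensorcircuit.cons` once this commit is applied:

from tensorcircuit.cons import get_symbol

assert get_symbol(0) == "a"  # the 52 ASCII letters come first
assert get_symbol(51) == "Z"
assert get_symbol(52) == chr(192)  # "À", since 52 + 140 = 192
assert get_symbol(55156) == chr(57344)  # indices hitting the surrogate range are shifted past it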
@@ -451,6 +473,16 @@ def _get_path(
     return algorithm(input_sets, output_set, size_dict), nodes


+def _identity(*args: Any, **kws: Any) -> Any:
+    return args
+
+
+def _sort_tuple_list(input_list: List[Any], output_list: List[Any]) -> List[Any]:
+    sorted_elements = [(tuple(sorted(t)), i) for i, t in enumerate(input_list)]
+    sorted_elements.sort()
+    return [output_list[i] for _, i in sorted_elements]
+
+
 def _get_path_cache_friendly(
     nodes: List[tn.Node], algorithm: Any
 ) -> Tuple[List[Tuple[int, int]], List[tn.Node]]:
@@ -460,18 +492,21 @@ def _get_path_cache_friendly(
     for n in nodes:
         for e in n:
             if id(e) not in mapping_dict:
-                mapping_dict[id(e)] = i
+                mapping_dict[id(e)] = get_symbol(i)
                 i += 1
     # TODO(@refraction-ray): may be not that cache friendly, since the edge id correspondence is not that fixed?
-    input_sets = [set([mapping_dict[id(e)] for e in node.edges]) for node in nodes]
-    placeholder = [[1e20 for _ in range(100)]]
-    order = np.argsort(np.array(list(map(sorted, input_sets)) + placeholder, dtype=object))[:-1]  # type: ignore
-    nodes_new = [nodes[i] for i in order]
+    input_sets = [list([mapping_dict[id(e)] for e in node.edges]) for node in nodes]
+    # placeholder = [[1e20 for _ in range(100)]]
+    # order = np.argsort(np.array(list(map(sorted, input_sets)), dtype=object))  # type: ignore
+    # nodes_new = [nodes[i] for i in order]
+    nodes_new = _sort_tuple_list(input_sets, nodes)
     if isinstance(algorithm, list):
         return algorithm, nodes_new

-    input_sets = [set([mapping_dict[id(e)] for e in node.edges]) for node in nodes_new]
-    output_set = set([mapping_dict[id(e)] for e in tn.get_subgraph_dangling(nodes_new)])
+    input_sets = [list([mapping_dict[id(e)] for e in node.edges]) for node in nodes_new]
+    output_set = list(
+        [mapping_dict[id(e)] for e in tn.get_subgraph_dangling(nodes_new)]
+    )
     size_dict = {
         mapping_dict[id(edge)]: edge.dimension for edge in tn.get_all_edges(nodes_new)
     }
@@ -483,6 +518,9 @@ def _get_path_cache_friendly(
 # directly get input_sets, output_set and size_dict by using identity function as algorithm


+get_tn_info = partial(_get_path_cache_friendly, algorithm=_identity)
+
+
 # some contractor setup usages
 """
 import cotengra as ctg
@@ -513,7 +551,8 @@ def _get_path_cache_friendly(

 def opt_reconf(inputs, output, size, **kws):
     tree = opt.search(inputs, output, size)
-    tree_r = tree.subtree_reconfigure_forest(progbar=True, num_trees=10, num_restarts=20, subtree_weight_what=("size", ))
+    tree_r = tree.subtree_reconfigure_forest(progbar=True, num_trees=10,
+        num_restarts=20, subtree_weight_what=("size", ))
     return tree_r.get_path()

 tc.set_contractor("custom", optimizer=opt_reconf)
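For intuition about the new deterministic node ordering, a tiny illustrative check, not part of the diff: `_sort_tuple_list` sorts the nodes by the sorted tuple of their edge symbols, replacing the earlier `np.argsort`-with-placeholder trick with a stable, hashable ordering.

from tensorcircuit.cons import _sort_tuple_list

# node "n1" carries edge symbols ["a"], node "n0" carries ["b", "a"];
# ("a",) sorts before ("a", "b"), so "n1" comes first in the new order
assert _sort_tuple_list([["b", "a"], ["a"]], ["n0", "n1"]) == ["n1", "n0"]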
