tc.set_dtype("complex128")


-def get_circuit(n, d, params):
+def circuit2nodes(n, d, params, tc_mpo):
    c = tc.Circuit(n)
    c.h(range(n))
    for i in range(d):
@@ -32,14 +32,15 @@ def get_circuit(n, d, params):
            c.rx(j, theta=params[j, i, 1])
        for j in range(n):
            c.ry(j, theta=params[j, i, 2])
-    return c

-
-def core(params, i, tree, n, d, tc_mpo):
-    c = get_circuit(n, d, params)
    mps = c.get_quvector()
    e = mps.adjoint() @ tc_mpo @ mps
-    _, nodes = tc.cons.get_tn_info(e.nodes)
+    return e.nodes
+
+
+def core(params, i, tree, n, d, tc_mpo):
+    nodes = circuit2nodes(n, d, params, tc_mpo)
+    _, nodes = tc.cons.get_tn_info(nodes)
    input_arrays = [node.tensor for node in nodes]
    sliced_arrays = tree.slice_arrays(input_arrays, i)
    return K.real(tree.contract_core(sliced_arrays, backend=backend))[0, 0]
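Note: slicing fixes the values of a few summed indices of the tensor network, so each of the num_device slices is an independent contraction and the partial results simply add up to the full expectation value. A minimal serial check of that identity, as a hedged sketch using the names defined in this file (not part of the commit):

# sanity check: summing core() over all slice indices recovers <psi|H|psi>
energy = sum(
    core(params, i, tree, nqubit, d, tc_mpo) for i in range(num_device)
)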
@@ -52,24 +53,18 @@ def core(params, i, tree, n, d, tc_mpo):
nqubit = 12
d = 6

-Jx = jax.numpy.array([1.0] * (nqubit - 1))  # XX coupling strength
-Bz = jax.numpy.array([-1.0] * nqubit)  # Transverse field strength
-
-# Create TensorNetwork MPO
-tn_mpo = tn.matrixproductstates.mpo.FiniteTFI(Jx, Bz, dtype=np.complex64)
-tc_mpo = tc.quantum.tn2qop(tn_mpo)
-
# baseline results
lattice = tc.templates.graphs.Line1D(nqubit, pbc=False)
h = tc.quantum.heisenberg_hamiltonian(lattice, hzz=0, hyy=0, hxx=1.0, hz=-1.0)
es0 = scipy.sparse.linalg.eigsh(K.numpy(h), k=1, which="SA")[0]
print("exact ground state energy: ", es0)

-params = K.implicit_randn(stddev=0.1, shape=[1, nqubit, d, 3], dtype=tc.rdtypestr)
-params = K.tile(params, [num_device, 1, 1, 1])
+params = K.implicit_randn(stddev=0.1, shape=[nqubit, d, 3], dtype=tc.rdtypestr)
+replicated_params = K.reshape(params, [1] + list(params.shape))
+replicated_params = K.tile(replicated_params, [num_device, 1, 1, 1])

optimizer = optax.adam(5e-2)
-base_opt_state = optimizer.init(params[0])
+base_opt_state = optimizer.init(params)
replicated_opt_state = jax.tree.map(
    lambda x: (
        jax.numpy.broadcast_to(x, (num_device,) + x.shape)
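Note: an optax optimizer state is a pytree of arrays, so jax.tree.map applies the broadcast to every leaf, yielding one state copy per device with a new leading axis of length num_device. A self-contained illustration of the pattern (hedged; the shapes and values here are placeholders, not the file's):

import jax
import jax.numpy as jnp
import optax

adam = optax.adam(5e-2)
state = adam.init(jnp.zeros((3,)))  # single-device state
replicated = jax.tree.map(
    lambda x: jnp.broadcast_to(x, (4,) + jnp.shape(x)), state
)  # every array leaf gains a leading device axis of size 4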
@@ -93,28 +88,32 @@ def para_vag(params, i, tree, n, d, tc_mpo, opt_state):
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

-c = get_circuit(nqubit, d, params[0])
-mps = c.get_quvector()
-e = mps.adjoint() @ tc_mpo @ mps
-tn_info, nodes = tc.cons.get_tn_info(e.nodes)
+Jx = jax.numpy.array([1.0] * (nqubit - 1))  # XX coupling strength
+Bz = jax.numpy.array([-1.0] * nqubit)  # Transverse field strength
+# Create TensorNetwork MPO
+tn_mpo = tn.matrixproductstates.mpo.FiniteTFI(Jx, Bz, dtype=np.complex64)
+tc_mpo = tc.quantum.tn2qop(tn_mpo)

+nodes = circuit2nodes(nqubit, d, params, tc_mpo)
+tn_info, _ = tc.cons.get_tn_info(nodes)
+
+# Create ReusableHyperOptimizer for finding optimal contraction paths
opt = ctg.ReusableHyperOptimizer(
-    parallel=True,
+    parallel=True,  # Enable parallel path finding
    slicing_opts={
-        "target_slices": num_device,
-        # "target_size": 2**20, # Add memory target
+        "target_slices": num_device,  # Split computation across available devices
+        # "target_size": 2**20,  # Optional: Set memory limit per slice
    },
-    max_repeats=256,
-    progbar=True,
-    minimize="combo",
+    max_repeats=256,  # Maximum number of path finding attempts
+    progbar=True,  # Show progress bar during optimization
+    minimize="combo",  # Optimize for both time and memory
)
-
tree = opt.search(*tn_info)

inds = K.arange(num_device)
for j in range(100):
    print(f"training loop: {j}-step")
-    params, replicated_opt_state, loss = para_vag(
-        params, inds, tree, nqubit, d, tc_mpo, replicated_opt_state
+    replicated_params, replicated_opt_state, loss = para_vag(
+        replicated_params, inds, tree, nqubit, d, tc_mpo, replicated_opt_state
    )
    print(loss[0])
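Note: para_vag is evidently mapped over the device axis: each device receives one slice index from inds, its own parameter copy, and its own optimizer state, and loss comes back with a leading device dimension (hence loss[0]). For the parameter copies to stay in sync, the per-slice gradients must be summed across devices before the optimizer update. A hedged sketch of such a body (the real para_vag is outside this diff; this would be wrapped with jax.pmap(..., axis_name="devices")):

def para_vag_sketch(params, i, tree, n, d, tc_mpo, opt_state):
    loss, grads = jax.value_and_grad(core)(params, i, tree, n, d, tc_mpo)
    # each device holds one slice; sum the partial results across devices
    loss = jax.lax.psum(loss, axis_name="devices")
    grads = jax.lax.psum(grads, axis_name="devices")
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss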