Skip to content

Commit 1ad13b1

Browse files
add vmap+grad example
1 parent 1f426d3 commit 1ad13b1

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed

examples/nested_vmap_grad.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
"""
2+
demonstration of vmap + grad like API
3+
"""
4+
5+
import tensorcircuit as tc
6+
7+
# See issues in https://github.com/tencent-quantum-lab/tensorcircuit/issues/229#issuecomment-2600773780
8+
9+
for backend in ["tensorflow", "jax"]:
10+
with tc.runtime_backend(backend) as K:
11+
L = 2
12+
inputs = K.cast(K.ones([3, 2]), tc.rdtypestr)
13+
weights = K.cast(K.ones([2]), tc.rdtypestr)
14+
15+
def ansatz(thetas, alpha):
16+
c = tc.Circuit(L)
17+
for j in range(2):
18+
for i in range(L):
19+
c.rx(i, theta=thetas[j])
20+
c.ry(i, theta=alpha[j])
21+
for i in range(L - 1):
22+
c.cnot(i, i + 1)
23+
return c
24+
25+
def f(thetas, alpha):
26+
c = ansatz(thetas, alpha)
27+
observables = K.stack([K.real(c.expectation_ps(z=[i])) for i in range(L)])
28+
return K.mean(observables)
29+
30+
# f_vmap = K.vmap(f, vectorized_argnums=0)
31+
32+
print("grad", K.grad(f)(inputs[0], weights))
33+
print("vmap", K.vmap(f)(inputs, weights))
34+
print("vmap over grad", K.vmap(K.grad(f))(inputs, weights))
35+
# wrong in tf due to https://github.com/google/TensorNetwork/issues/940
36+
# https://github.com/tensorflow/tensorflow/issues/52148
37+
print("vmap over jacfwd", K.vmap(K.jacfwd(f))(inputs, weights))
38+
print("jacfwd over vmap", K.jacfwd(K.vmap(f))(inputs, weights))
39+
r = K.vmap(K.jacrev(f))(inputs, weights)
40+
print("vmap over jacrev", r)
41+
# wrong in tf
42+
r = K.jacrev(K.vmap(f))(inputs, weights)
43+
print("jacrev over vmap", r)
44+
r = K.vmap(K.hessian(f))(inputs, weights)
45+
print("vmap over hess", r)
46+
# wrong in tf
47+
r = K.hessian(K.vmap(f))(inputs, weights)
48+
print("hess over vmap", r)
49+
50+
# lessons: never put vmap outside gradient function in tf

0 commit comments

Comments
 (0)