|
| 1 | +""" |
| 2 | +demonstration of vmap + grad like API |
| 3 | +""" |
| 4 | + |
| 5 | +import tensorcircuit as tc |
| 6 | + |
| 7 | +# See issues in https://github.com/tencent-quantum-lab/tensorcircuit/issues/229#issuecomment-2600773780 |
| 8 | + |
| 9 | +for backend in ["tensorflow", "jax"]: |
| 10 | + with tc.runtime_backend(backend) as K: |
| 11 | + L = 2 |
| 12 | + inputs = K.cast(K.ones([3, 2]), tc.rdtypestr) |
| 13 | + weights = K.cast(K.ones([2]), tc.rdtypestr) |
| 14 | + |
| 15 | + def ansatz(thetas, alpha): |
| 16 | + c = tc.Circuit(L) |
| 17 | + for j in range(2): |
| 18 | + for i in range(L): |
| 19 | + c.rx(i, theta=thetas[j]) |
| 20 | + c.ry(i, theta=alpha[j]) |
| 21 | + for i in range(L - 1): |
| 22 | + c.cnot(i, i + 1) |
| 23 | + return c |
| 24 | + |
| 25 | + def f(thetas, alpha): |
| 26 | + c = ansatz(thetas, alpha) |
| 27 | + observables = K.stack([K.real(c.expectation_ps(z=[i])) for i in range(L)]) |
| 28 | + return K.mean(observables) |
| 29 | + |
| 30 | + # f_vmap = K.vmap(f, vectorized_argnums=0) |
| 31 | + |
| 32 | + print("grad", K.grad(f)(inputs[0], weights)) |
| 33 | + print("vmap", K.vmap(f)(inputs, weights)) |
| 34 | + print("vmap over grad", K.vmap(K.grad(f))(inputs, weights)) |
| 35 | + # wrong in tf due to https://github.com/google/TensorNetwork/issues/940 |
| 36 | + # https://github.com/tensorflow/tensorflow/issues/52148 |
| 37 | + print("vmap over jacfwd", K.vmap(K.jacfwd(f))(inputs, weights)) |
| 38 | + print("jacfwd over vmap", K.jacfwd(K.vmap(f))(inputs, weights)) |
| 39 | + r = K.vmap(K.jacrev(f))(inputs, weights) |
| 40 | + print("vmap over jacrev", r) |
| 41 | + # wrong in tf |
| 42 | + r = K.jacrev(K.vmap(f))(inputs, weights) |
| 43 | + print("jacrev over vmap", r) |
| 44 | + r = K.vmap(K.hessian(f))(inputs, weights) |
| 45 | + print("vmap over hess", r) |
| 46 | + # wrong in tf |
| 47 | + r = K.hessian(K.vmap(f))(inputs, weights) |
| 48 | + print("hess over vmap", r) |
| 49 | + |
| 50 | +# lessons: never put vmap outside gradient function in tf |
0 commit comments