Skip to content

Commit 0660abf

Browse files
committed
Merge remote-tracking branch 'upstream/master' into feature/lattice-updates
2 parents 7063c6f + 03fa1c5 commit 0660abf

File tree

4 files changed

+30
-11
lines changed

4 files changed

+30
-11
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434

3535
- `MPSCircuit` now will first try to transform `QuVector` input to tensors directly instead of evaluating it to dense wavefunction first.
3636

37+
- Fix to use `status` for `circuit.sample` when `allow_state=True`.
38+
3739
### Changed
3840

3941
- The order of arguments of `tc.timeevol.ed_evol` are changed for consistent interface with other evolution methods.

tensorcircuit/backends/tensorflow_backend.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,9 @@ def argsort(self, a: Tensor, axis: int = -1) -> Tensor:
564564
def sort(self, a: Tensor, axis: int = -1) -> Tensor:
565565
return tf.sort(a, axis=axis)
566566

567+
def shape_tuple(self, a: Tensor) -> Tuple[int, ...]:
568+
return tuple(a.shape)
569+
567570
def unique_with_counts(self, a: Tensor, **kws: Any) -> Tuple[Tensor, Tensor]:
568571
r = tf.unique_with_counts(a)
569572
order = tf.argsort(r.y)

tensorcircuit/basecircuit.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,8 @@ def sample(
576576
:param random_generator: random generator, defaults to None
577577
:type random_generator: Optional[Any], optional
578578
:param status: external randomness given by tensor uniformly from [0, 1],
579-
if set, can overwrite random_generator
579+
if set, can overwrite random_generator, shape [batch] for `allow_state=True`
580+
and shape [batch, nqubits] for `allow_state=False` using perfect sampling implementation
580581
:type status: Optional[Tensor]
581582
:param jittable: when converting to count, whether keep the full size. if false, may be conflict
582583
external jit, if true, may fail for large scale system with actual limited count results
@@ -597,20 +598,24 @@ def sample(
597598
return r
598599
r = [r] # type: ignore
599600
else:
601+
r = [] # type: ignore
602+
if status is not None:
603+
assert backend.shape_tuple(status)[0] == batch
604+
for seed in status:
605+
r.append(self.perfect_sampling(seed)) # type: ignore
600606

601-
@backend.jit
602-
def perfect_sampling(key: Any) -> Any:
603-
backend.set_random_state(key)
604-
return self.perfect_sampling()
607+
else:
605608

606-
# TODO(@refraction-ray): status is not used here
609+
@backend.jit
610+
def perfect_sampling(key: Any) -> Any:
611+
backend.set_random_state(key)
612+
return self.perfect_sampling()
607613

608-
r = [] # type: ignore
614+
subkey = random_generator
615+
for _ in range(batch):
616+
key, subkey = backend.random_split(subkey)
617+
r.append(perfect_sampling(key)) # type: ignore
609618

610-
subkey = random_generator
611-
for _ in range(batch):
612-
key, subkey = backend.random_split(subkey)
613-
r.append(perfect_sampling(key)) # type: ignore
614619
if format is None:
615620
return r
616621
r = backend.stack([ri[0] for ri in r]) # type: ignore

tests/test_circuit.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1285,6 +1285,7 @@ def test_batch_sample(backend):
12851285
c.H(0)
12861286
c.cnot(0, 1)
12871287
print(c.sample())
1288+
print(c.sample(batch=8, status=np.random.uniform(size=[8, 3])))
12881289
print(c.sample(batch=8))
12891290
print(c.sample(random_generator=tc.backend.get_random_state(42)))
12901291
print(c.sample(allow_state=True))
@@ -1302,6 +1303,14 @@ def test_batch_sample(backend):
13021303
format="sample_bin",
13031304
)
13041305
)
1306+
print(
1307+
c.sample(
1308+
batch=8,
1309+
allow_state=False,
1310+
status=np.random.uniform(size=[8, 3]),
1311+
format="sample_bin",
1312+
)
1313+
)
13051314

13061315

13071316
def test_expectation_y_bug():

0 commit comments

Comments
 (0)