Skip to content

Commit 625de73

Browse files
add jittable option in sample
1 parent 3eab8ba commit 625de73

File tree

3 files changed

+24
-1
lines changed

3 files changed

+24
-1
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
## Unreleased
44

5+
### Added
6+
7+
- Add `jittable` option in `c.sample()` method, friendly to switch off for large scale sample
8+
59
### Changed
610

711
- Downgrading some logger warning to info

tensorcircuit/basecircuit.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,7 @@ def sample(
528528
format: Optional[str] = None,
529529
random_generator: Optional[Any] = None,
530530
status: Optional[Tensor] = None,
531+
jittable: bool = True,
531532
) -> Any:
532533
"""
533534
batched sampling from state or circuit tensor network directly
@@ -547,6 +548,9 @@ def sample(
547548
:param status: external randomness given by tensor uniformly from [0, 1],
548549
if set, can overwrite random_generator
549550
:type status: Optional[Tensor]
551+
:param jittable: when converting to count, whether keep the full size. if false, may be conflict
552+
external jit, if true, may fail for large scale system with actual limited count results
553+
:type jittable: bool, defaults true
550554
:return: List (if batch) of tuple (binary configuration tensor and corresponding probability)
551555
if the format is None, and consistent with format when given
552556
:rtype: Any
@@ -612,7 +616,7 @@ def perfect_sampling(key: Any) -> Any:
612616
if batch is None:
613617
r = r[0] # type: ignore
614618
return r
615-
return sample2all(sample=ch, n=self._nqubits, format=format, jittable=True)
619+
return sample2all(sample=ch, n=self._nqubits, format=format, jittable=jittable)
616620

617621
def sample_expectation_ps(
618622
self,

tests/test_circuit.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,21 @@ def f5(key):
236236
)
237237

238238

239+
@pytest.mark.parametrize("backend", [lf("jaxb")]) # too slow for np
240+
def test_large_scale_sample(backend):
241+
L = 30
242+
c = tc.Circuit(L)
243+
c.h(0)
244+
c.cnot([i for i in range(L - 1)], [i + 1 for i in range(L - 1)])
245+
results = c.sample(
246+
allow_state=False, batch=1024, format="count_dict_bin", jittable=False
247+
)
248+
assert (
249+
results["0" * L] / results["1" * L] < 1.2
250+
and results["0" * L] / results["1" * L] > 0.8
251+
)
252+
253+
239254
@pytest.mark.parametrize("backend", [lf("npb"), lf("cpb")])
240255
def test_expectation(backend):
241256
c = tc.Circuit(2)

0 commit comments

Comments
 (0)