@@ -5498,6 +5498,7 @@ def get_npdm(
5498
5498
algo_type = None ,
5499
5499
npdm_expr = None ,
5500
5500
mask = None ,
5501
+ index_masks = None ,
5501
5502
simulated_parallel = 0 ,
5502
5503
fused_contraction_rotation = True ,
5503
5504
cutoff = 1e-24 ,
@@ -5537,6 +5538,9 @@ def get_npdm(
5537
5538
mask : None or list[int] or list[list[int]]
5538
5539
The mask for setting repeated indices for the operator expression.
5539
5540
Default is None, meaning that all indices can be different.
5541
+ index_masks : None or list[list[int]] or list[list[list[int]]]
5542
+ The list of allowed site indices for each operator in the operator expression.
5543
+ Default is None, meaning that all site indices are allowed.
5540
5544
simulated_parallel : int
5541
5545
Number of processors for simulating parallel algorithm serially.
5542
5546
Default is zero, meaning that the serial algorithm is used if
@@ -5629,7 +5633,7 @@ def get_npdm(
5629
5633
for _ in range (pdm_type - 1 ):
5630
5634
op_str = su2_coupling % op_str
5631
5635
perm = bw .b .SpinPermScheme .initialize_su2 (
5632
- pdm_type * 2 , op_str , True ,
5636
+ int ( pdm_type * 2 ) , op_str , True ,
5633
5637
mask = bw .b .VectorUInt16 () if mask is None else bw .b .VectorUInt16 (mask ),
5634
5638
max_n_sites = ket .n_sites ,
5635
5639
)
@@ -5646,41 +5650,33 @@ def get_npdm(
5646
5650
if mask is None :
5647
5651
perms = bw .b .VectorSpinPermScheme (
5648
5652
[
5649
- bw .b .SpinPermScheme .initialize_sz (pdm_type * 2 , cd , True ,
5653
+ bw .b .SpinPermScheme .initialize_sz (len ( cd ) , cd , True ,
5650
5654
mask = bw .b .VectorUInt16 (), max_n_sites = ket .n_sites ) if fermionic_ops is None else
5651
- bw .b .SpinPermScheme .initialize_sany (pdm_type * 2 , cd , fermionic_ops ,
5655
+ bw .b .SpinPermScheme .initialize_sany (len ( cd ) , cd , fermionic_ops ,
5652
5656
mask = bw .b .VectorUInt16 (), max_n_sites = ket .n_sites ) for cd in op_str
5653
5657
]
5654
5658
)
5655
5659
elif len (mask ) != 0 and not isinstance (mask [0 ], int ):
5656
5660
assert len (mask ) == len (op_str )
5657
- pts = (
5658
- [pdm_type ] * len (op_str ) if isinstance (pdm_type , int ) else pdm_type
5659
- )
5660
5661
perms = bw .b .VectorSpinPermScheme (
5661
5662
[
5662
5663
bw .b .SpinPermScheme .initialize_sz (
5663
- pt * 2 , cd , True , mask = bw .b .VectorUInt16 (xm ), max_n_sites = ket .n_sites
5664
+ len ( cd ) , cd , True , mask = bw .b .VectorUInt16 (xm ), max_n_sites = ket .n_sites
5664
5665
) if fermionic_ops is None else
5665
5666
bw .b .SpinPermScheme .initialize_sany (
5666
- pt * 2 , cd , fermionic_ops , mask = bw .b .VectorUInt16 (xm ), max_n_sites = ket .n_sites
5667
- )
5668
- for cd , xm , pt in zip (op_str , mask , pts )
5667
+ len (cd ), cd , fermionic_ops , mask = bw .b .VectorUInt16 (xm ), max_n_sites = ket .n_sites
5668
+ ) for cd , xm in zip (op_str , mask )
5669
5669
]
5670
5670
)
5671
5671
else :
5672
- pts = (
5673
- [pdm_type ] * len (op_str ) if isinstance (pdm_type , int ) else pdm_type
5674
- )
5675
5672
perms = bw .b .VectorSpinPermScheme (
5676
5673
[
5677
5674
bw .b .SpinPermScheme .initialize_sz (
5678
- pt * 2 , cd , True , mask = bw .b .VectorUInt16 (mask ), max_n_sites = ket .n_sites
5675
+ len ( cd ) , cd , True , mask = bw .b .VectorUInt16 (mask ), max_n_sites = ket .n_sites
5679
5676
) if fermionic_ops is None else
5680
5677
bw .b .SpinPermScheme .initialize_sany (
5681
- pt * 2 , cd , fermionic_ops , mask = bw .b .VectorUInt16 (mask ), max_n_sites = ket .n_sites
5682
- )
5683
- for cd , pt in zip (op_str , pts )
5678
+ len (cd ), cd , fermionic_ops , mask = bw .b .VectorUInt16 (mask ), max_n_sites = ket .n_sites
5679
+ ) for cd in op_str
5684
5680
]
5685
5681
)
5686
5682
elif SymmetryTypes .SGF in bw .symm_type :
@@ -5693,41 +5689,34 @@ def get_npdm(
5693
5689
if mask is None :
5694
5690
perms = bw .b .VectorSpinPermScheme (
5695
5691
[
5696
- bw .b .SpinPermScheme .initialize_sz (pdm_type * 2 , cd , True ,
5692
+ bw .b .SpinPermScheme .initialize_sz (len ( cd ) , cd , True ,
5697
5693
mask = bw .b .VectorUInt16 (), max_n_sites = ket .n_sites ) if fermionic_ops is None else
5698
- bw .b .SpinPermScheme .initialize_sany (pdm_type * 2 , cd , fermionic_ops ,
5694
+ bw .b .SpinPermScheme .initialize_sany (len ( cd ) , cd , fermionic_ops ,
5699
5695
mask = bw .b .VectorUInt16 (), max_n_sites = ket .n_sites ) for cd in op_str
5700
5696
]
5701
5697
)
5702
5698
elif len (mask ) != 0 and not isinstance (mask [0 ], int ):
5703
5699
assert len (mask ) == len (op_str )
5704
- pts = (
5705
- [pdm_type ] * len (op_str ) if isinstance (pdm_type , int ) else pdm_type
5706
- )
5707
5700
perms = bw .b .VectorSpinPermScheme (
5708
5701
[
5709
5702
bw .b .SpinPermScheme .initialize_sz (
5710
- pt * 2 , cd , True , mask = bw .b .VectorUInt16 (xm ), max_n_sites = ket .n_sites
5703
+ len ( cd ) , cd , True , mask = bw .b .VectorUInt16 (xm ), max_n_sites = ket .n_sites
5711
5704
) if fermionic_ops is None else
5712
5705
bw .b .SpinPermScheme .initialize_sany (
5713
- pt * 2 , cd , fermionic_ops , mask = bw .b .VectorUInt16 (xm ), max_n_sites = ket .n_sites
5706
+ len ( cd ) , cd , fermionic_ops , mask = bw .b .VectorUInt16 (xm ), max_n_sites = ket .n_sites
5714
5707
)
5715
- for cd , xm , pt in zip (op_str , mask , pts )
5708
+ for cd , xm in zip (op_str , mask )
5716
5709
]
5717
5710
)
5718
5711
else :
5719
- pts = (
5720
- [pdm_type ] * len (op_str ) if isinstance (pdm_type , int ) else pdm_type
5721
- )
5722
5712
perms = bw .b .VectorSpinPermScheme (
5723
5713
[
5724
5714
bw .b .SpinPermScheme .initialize_sz (
5725
- pt * 2 , cd , True , mask = bw .b .VectorUInt16 (mask ), max_n_sites = ket .n_sites
5715
+ len ( cd ) , cd , True , mask = bw .b .VectorUInt16 (mask ), max_n_sites = ket .n_sites
5726
5716
) if fermionic_ops is None else
5727
5717
bw .b .SpinPermScheme .initialize_sany (
5728
- pt * 2 , cd , fermionic_ops , mask = bw .b .VectorUInt16 (mask ), max_n_sites = ket .n_sites
5729
- )
5730
- for cd , pt in zip (op_str , pts )
5718
+ len (cd ), cd , fermionic_ops , mask = bw .b .VectorUInt16 (mask ), max_n_sites = ket .n_sites
5719
+ ) for cd in op_str
5731
5720
]
5732
5721
)
5733
5722
elif SymmetryTypes .SAny in bw .symm_type :
@@ -5761,6 +5750,19 @@ def get_npdm(
5761
5750
]
5762
5751
)
5763
5752
5753
+ if index_masks is not None :
5754
+ if any (any (isinstance (im , (int , np .int64 )) for im in ims ) for ims in index_masks if len (ims ) != 0 ):
5755
+ index_masks = [index_masks ]
5756
+ if len (index_masks ) == 1 :
5757
+ index_masks = index_masks * len (perms )
5758
+ assert len (perms ) == len (index_masks )
5759
+ for perm , ims in zip (perms , index_masks ):
5760
+ f = lambda x : [int (px ) for px in x ]
5761
+ if self .reorder_idx is not None :
5762
+ rev_idx = np .argsort (self .reorder_idx )
5763
+ f = lambda x : [rev_idx [int (px )] for px in x ]
5764
+ perm .index_mask = bw .b .VectorVectorUInt16 ([bw .b .VectorUInt16 (f (x )) for x in ims ])
5765
+
5764
5766
if iprint >= 1 :
5765
5767
print ("npdm string =" , op_str )
5766
5768
@@ -5949,7 +5951,7 @@ def get_npdm(
5949
5951
for ip in range (len (npdms )):
5950
5952
npdms [ip ] = np .asarray (npdms [ip ])
5951
5953
5952
- if self .reorder_idx is not None :
5954
+ if self .reorder_idx is not None and index_masks is None :
5953
5955
rev_idx = np .argsort (self .reorder_idx )
5954
5956
for ip in range (len (npdms )):
5955
5957
for i in range (npdms [ip ].ndim ):
0 commit comments