Skip to content

Commit e105763

Browse files
add psd option for sqrtm
1 parent 3090498 commit e105763

File tree

3 files changed

+13
-2
lines changed

3 files changed

+13
-2
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
- Add `u1_inds`, `u1_mask`, `u1_project`, and `u1_enlarge` functions in quantum.py as utilities in charged conservation systems
1414

15+
- Add `psd` boolean to `sqrtmh` method for backends
16+
1517
### Fixed
1618

1719
- Fix customized jax eigh operator by noting the return is a namedtuple

tensorcircuit/backends/abstract_backend.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,17 +46,22 @@ def expm(self: Any, a: Tensor) -> Tensor:
4646
"Backend '{}' has not implemented `expm`.".format(self.name)
4747
)
4848

49-
def sqrtmh(self: Any, a: Tensor) -> Tensor:
49+
def sqrtmh(self: Any, a: Tensor, psd: bool = False) -> Tensor:
5050
"""
5151
Return the sqrtm of a Hermitian matrix ``a``.
5252
5353
:param a: tensor in matrix form
5454
:type a: Tensor
55+
:param psd: whether the input ``a`` is guaranteed as a positive semidefinite matrix,
56+
defaults False
57+
:type psd: bool
5558
:return: sqrtm of ``a``
5659
:rtype: Tensor
5760
"""
5861
# maybe friendly for AD and also cosidering that several backend has no support for native sqrtm
5962
e, v = self.eigh(a)
63+
if psd:
64+
e = self.relu(e)
6065
e = self.sqrt(e)
6166
return v @ self.diagflat(e) @ self.adjoint(v)
6267

tests/test_backends.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,12 @@ def test_backend_methods(backend):
119119
ans = np.array([[1, 0.5j], [-0.5j, 1]])
120120
ans2 = ans @ ans
121121
ansp = tc.backend.sqrtmh(tc.array_to_tensor(ans2))
122-
print(ansp @ ansp, ans @ ans)
122+
# print(ansp @ ansp, ans @ ans)
123123
np.testing.assert_allclose(ansp @ ansp, ans @ ans, atol=1e-4)
124+
singularm = np.array([[4.0, 0], [0, -1e-3]])
125+
np.testing.assert_allclose(
126+
tc.backend.sqrtmh(singularm, psd=True), np.array([[2.0, 0], [0, 0]]), atol=1e-5
127+
)
124128

125129
np.testing.assert_allclose(
126130
tc.backend.sum(tc.array_to_tensor(np.arange(4))), 6, atol=1e-4

0 commit comments

Comments
 (0)