Skip to content

Commit e040057

Browse files
Add view_as_complex and view_as_real
1 parent c62a1e3 commit e040057

File tree

4 files changed

+135
-5
lines changed

4 files changed

+135
-5
lines changed

python/paddle/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,8 @@
362362
unstack,
363363
view,
364364
view_as,
365+
view_as_complex,
366+
view_as_real,
365367
vsplit,
366368
vstack,
367369
)
@@ -1155,7 +1157,9 @@
11551157
'acosh',
11561158
'atanh',
11571159
'as_complex',
1160+
'view_as_complex',
11581161
'as_real',
1162+
'view_as_real',
11591163
'diff',
11601164
'angle',
11611165
'fmax',

python/paddle/tensor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,8 @@
226226
unstack,
227227
view,
228228
view_as,
229+
view_as_complex,
230+
view_as_real,
229231
vsplit,
230232
vstack,
231233
)
@@ -776,7 +778,9 @@
776778
'lu_unpack',
777779
'cdist',
778780
'as_complex',
781+
'view_as_complex',
779782
'as_real',
783+
'view_as_real',
780784
'rad2deg',
781785
'deg2rad',
782786
'gcd',

python/paddle/tensor/manipulation.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6242,6 +6242,83 @@ def as_real(x: Tensor, name: str | None = None) -> Tensor:
62426242
return out
62436243

62446244

6245+
def view_as_complex(input: Tensor) -> Tensor:
6246+
"""Return a complex tensor that is a view of the input real tensor .
6247+
6248+
The data type of the input tensor is 'float32' or 'float64', and the data
6249+
type of the returned tensor is 'complex64' or 'complex128', respectively.
6250+
6251+
The shape of the input tensor is ``(* ,2)``, (``*`` means arbitrary shape), i.e.
6252+
the size of the last axis should be 2, which represent the real and imag part
6253+
of a complex number. The shape of the returned tensor is ``(*,)``.
6254+
6255+
The complex tensor is a view of the input real tensor, meaning that it shares the same memory with real tensor.
6256+
6257+
The image below demonstrates the case that a real 3D-tensor with shape [2, 3, 2] is transformed into a complex 2D-tensor with shape [2, 3].
6258+
6259+
.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/as_complex.png
6260+
:width: 500
6261+
:alt: Illustration of as_complex
6262+
:align: center
6263+
6264+
Args:
6265+
input (Tensor): The input tensor. Data type is 'float32' or 'float64'.
6266+
6267+
Returns:
6268+
Tensor, The output. Data type is 'complex64' or 'complex128', sharing the same memory with input.
6269+
6270+
Examples:
6271+
.. code-block:: python
6272+
6273+
>>> import paddle
6274+
>>> x = paddle.arange(12, dtype=paddle.float32).reshape([2, 3, 2])
6275+
>>> y = paddle.as_complex(x)
6276+
>>> print(y)
6277+
Tensor(shape=[2, 3], dtype=complex64, place=Place(cpu), stop_gradient=True,
6278+
[[1j , (2+3j) , (4+5j) ],
6279+
[(6+7j) , (8+9j) , (10+11j)]])
6280+
"""
6281+
6282+
return as_complex(x=input)
6283+
6284+
6285+
def view_as_real(input: Tensor) -> Tensor:
6286+
"""Return a real tensor that is a view of the input complex tensor.
6287+
6288+
The data type of the input tensor is 'complex64' or 'complex128', and the data
6289+
type of the returned tensor is 'float32' or 'float64', respectively.
6290+
6291+
When the shape of the input tensor is ``(*, )``, (``*`` means arbitrary shape),
6292+
the shape of the output tensor is ``(*, 2)``, i.e. the shape of the output is
6293+
the shape of the input appended by an extra ``2``.
6294+
6295+
The real tensor is a view of the input complex tensor, meaning that it shares the same memory with complex tensor.
6296+
6297+
Args:
6298+
input (Tensor): The input tensor. Data type is 'complex64' or 'complex128'.
6299+
6300+
Returns:
6301+
Tensor, The output. Data type is 'float32' or 'float64', sharing the same memory with input.
6302+
6303+
Examples:
6304+
.. code-block:: python
6305+
6306+
>>> import paddle
6307+
>>> x = paddle.arange(12, dtype=paddle.float32).reshape([2, 3, 2])
6308+
>>> y = paddle.as_complex(x)
6309+
>>> z = paddle.as_real(y)
6310+
>>> print(z)
6311+
Tensor(shape=[2, 3, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
6312+
[[[0. , 1. ],
6313+
[2. , 3. ],
6314+
[4. , 5. ]],
6315+
[[6. , 7. ],
6316+
[8. , 9. ],
6317+
[10., 11.]]])
6318+
"""
6319+
return as_real(x=input)
6320+
6321+
62456322
def repeat_interleave(
62466323
x: Tensor,
62476324
repeats: int | Tensor,

test/legacy_test/test_complex_view_op.py

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def ref_view_as_real(x):
3333
return np.stack([x.real, x.imag], -1)
3434

3535

36-
class TestViewAsComplexOp(OpTest):
36+
class TestAsComplexOp(OpTest):
3737
def setUp(self):
3838
self.op_type = "as_complex"
3939
self.python_api = paddle.as_complex
@@ -53,7 +53,7 @@ def test_check_grad(self):
5353
)
5454

5555

56-
class TestViewAsRealOp(OpTest):
56+
class TestAsRealOp(OpTest):
5757
def setUp(self):
5858
self.op_type = "as_real"
5959
real = np.random.randn(10, 10).astype("float64")
@@ -75,7 +75,7 @@ def test_check_grad(self):
7575
)
7676

7777

78-
class TestViewAsComplexAPI(unittest.TestCase):
78+
class TestAsComplexAPI(unittest.TestCase):
7979
def setUp(self):
8080
self.x = np.random.randn(10, 10, 2)
8181
self.out = ref_view_as_complex(self.x)
@@ -98,7 +98,7 @@ def test_static(self):
9898
np.testing.assert_allclose(self.out, out_np, rtol=1e-05)
9999

100100

101-
class TestViewAsRealAPI(unittest.TestCase):
101+
class TestAsRealAPI(unittest.TestCase):
102102
def setUp(self):
103103
self.x = np.random.randn(10, 10) + 1j * np.random.randn(10, 10)
104104
self.out = ref_view_as_real(self.x)
@@ -121,7 +121,7 @@ def test_static(self):
121121
np.testing.assert_allclose(self.out, out_np, rtol=1e-05)
122122

123123

124-
class TestViewAsRealAPI_ZeroSize(unittest.TestCase):
124+
class TestAsRealAPI_ZeroSize(unittest.TestCase):
125125
def setUp(self):
126126
self.x = np.random.randn(10, 0) + 1j * np.random.randn(10, 0)
127127
self.out = ref_view_as_real(self.x)
@@ -137,5 +137,50 @@ def test_dygraph(self):
137137
np.testing.assert_allclose(x_tensor.grad.shape, x_tensor.shape)
138138

139139

140+
class TestViewAsComplexAPI(unittest.TestCase):
141+
def setUp(self):
142+
self.x = np.random.randn(10, 10, 2)
143+
self.out = ref_view_as_complex(self.x)
144+
145+
def test_dygraph(self):
146+
with dygraph.guard():
147+
x = paddle.to_tensor(self.x)
148+
out = paddle.view_as_complex(x)
149+
out_np = out.numpy()
150+
self.assertEqual(out.data_ptr(), x.data_ptr())
151+
np.testing.assert_allclose(self.out, out_np, rtol=1e-05)
152+
153+
154+
class TestViewAsRealAPI(unittest.TestCase):
155+
def setUp(self):
156+
self.x = np.random.randn(10, 10) + 1j * np.random.randn(10, 10)
157+
self.out = ref_view_as_real(self.x)
158+
159+
def test_dygraph(self):
160+
with dygraph.guard():
161+
x = paddle.to_tensor(self.x)
162+
out = paddle.view_as_real(x)
163+
out_np = out.numpy()
164+
self.assertEqual(out.data_ptr(), x.data_ptr())
165+
np.testing.assert_allclose(self.out, out_np, rtol=1e-05)
166+
167+
168+
class TestViewAsRealAPI_ZeroSize(unittest.TestCase):
169+
def setUp(self):
170+
self.x = np.random.randn(10, 0) + 1j * np.random.randn(10, 0)
171+
self.out = ref_view_as_real(self.x)
172+
173+
def test_dygraph(self):
174+
for place in get_places():
175+
with dygraph.guard(place):
176+
x_tensor = paddle.to_tensor(self.x)
177+
x_tensor.stop_gradient = False
178+
out = paddle.view_as_real(x_tensor)
179+
np.testing.assert_allclose(self.out, out.numpy(), rtol=1e-05)
180+
self.assertEqual(out.data_ptr(), x_tensor.data_ptr())
181+
out.sum().backward()
182+
np.testing.assert_allclose(x_tensor.grad.shape, x_tensor.shape)
183+
184+
140185
if __name__ == "__main__":
141186
unittest.main()

0 commit comments

Comments
 (0)