Skip to content

Commit 8321bbb

Browse files
[API Compatibility] Add view_as_complex and view_as_real APIs (#74466)
* view as real * Cherry-pick view_as_real and view as complex * Remove Param decorator
1 parent c2f8e7c commit 8321bbb

File tree

4 files changed

+135
-6
lines changed

4 files changed

+135
-6
lines changed

python/paddle/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,8 @@
369369
unstack,
370370
view,
371371
view_as,
372+
view_as_complex,
373+
view_as_real,
372374
vsplit,
373375
vstack,
374376
)
@@ -1167,7 +1169,9 @@
11671169
'acosh',
11681170
'atanh',
11691171
'as_complex',
1172+
'view_as_complex',
11701173
'as_real',
1174+
'view_as_real',
11711175
'diff',
11721176
'angle',
11731177
'fmax',

python/paddle/tensor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,8 @@
227227
unstack,
228228
view,
229229
view_as,
230+
view_as_complex,
231+
view_as_real,
230232
vsplit,
231233
vstack,
232234
)
@@ -783,7 +785,9 @@
783785
'lu_unpack',
784786
'cdist',
785787
'as_complex',
788+
'view_as_complex',
786789
'as_real',
790+
'view_as_real',
787791
'rad2deg',
788792
'deg2rad',
789793
'gcd',

python/paddle/tensor/manipulation.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6298,7 +6298,83 @@ def as_real(x: Tensor, name: str | None = None) -> Tensor:
62986298
return out
62996299

63006300

6301-
@ParamAliasDecorator({"x": ["input"], "axis": ["dim"]})
6301+
def view_as_complex(input: Tensor) -> Tensor:
6302+
"""Return a complex tensor that is a view of the input real tensor .
6303+
6304+
The data type of the input tensor is 'float32' or 'float64', and the data
6305+
type of the returned tensor is 'complex64' or 'complex128', respectively.
6306+
6307+
The shape of the input tensor is ``(* ,2)``, (``*`` means arbitrary shape), i.e.
6308+
the size of the last axis should be 2, which represent the real and imag part
6309+
of a complex number. The shape of the returned tensor is ``(*,)``.
6310+
6311+
The complex tensor is a view of the input real tensor, meaning that it shares the same memory with real tensor.
6312+
6313+
The image below demonstrates the case that a real 3D-tensor with shape [2, 3, 2] is transformed into a complex 2D-tensor with shape [2, 3].
6314+
6315+
.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/as_complex.png
6316+
:width: 500
6317+
:alt: Illustration of as_complex
6318+
:align: center
6319+
6320+
Args:
6321+
input (Tensor): The input tensor. Data type is 'float32' or 'float64'.
6322+
6323+
Returns:
6324+
Tensor, The output. Data type is 'complex64' or 'complex128', sharing the same memory with input.
6325+
6326+
Examples:
6327+
.. code-block:: python
6328+
6329+
>>> import paddle
6330+
>>> x = paddle.arange(12, dtype=paddle.float32).reshape([2, 3, 2])
6331+
>>> y = paddle.as_complex(x)
6332+
>>> print(y)
6333+
Tensor(shape=[2, 3], dtype=complex64, place=Place(cpu), stop_gradient=True,
6334+
[[1j , (2+3j) , (4+5j) ],
6335+
[(6+7j) , (8+9j) , (10+11j)]])
6336+
"""
6337+
6338+
return as_complex(x=input)
6339+
6340+
6341+
def view_as_real(input: Tensor) -> Tensor:
6342+
"""Return a real tensor that is a view of the input complex tensor.
6343+
6344+
The data type of the input tensor is 'complex64' or 'complex128', and the data
6345+
type of the returned tensor is 'float32' or 'float64', respectively.
6346+
6347+
When the shape of the input tensor is ``(*, )``, (``*`` means arbitrary shape),
6348+
the shape of the output tensor is ``(*, 2)``, i.e. the shape of the output is
6349+
the shape of the input appended by an extra ``2``.
6350+
6351+
The real tensor is a view of the input complex tensor, meaning that it shares the same memory with complex tensor.
6352+
6353+
Args:
6354+
input (Tensor): The input tensor. Data type is 'complex64' or 'complex128'.
6355+
6356+
Returns:
6357+
Tensor, The output. Data type is 'float32' or 'float64', sharing the same memory with input.
6358+
6359+
Examples:
6360+
.. code-block:: python
6361+
6362+
>>> import paddle
6363+
>>> x = paddle.arange(12, dtype=paddle.float32).reshape([2, 3, 2])
6364+
>>> y = paddle.as_complex(x)
6365+
>>> z = paddle.as_real(y)
6366+
>>> print(z)
6367+
Tensor(shape=[2, 3, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
6368+
[[[0. , 1. ],
6369+
[2. , 3. ],
6370+
[4. , 5. ]],
6371+
[[6. , 7. ],
6372+
[8. , 9. ],
6373+
[10., 11.]]])
6374+
"""
6375+
return as_real(x=input)
6376+
6377+
63026378
def repeat_interleave(
63036379
x: Tensor,
63046380
repeats: int | Tensor,

test/legacy_test/test_complex_view_op.py

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def ref_view_as_real(x):
3333
return np.stack([x.real, x.imag], -1)
3434

3535

36-
class TestViewAsComplexOp(OpTest):
36+
class TestAsComplexOp(OpTest):
3737
def setUp(self):
3838
self.op_type = "as_complex"
3939
self.python_api = paddle.as_complex
@@ -53,7 +53,7 @@ def test_check_grad(self):
5353
)
5454

5555

56-
class TestViewAsRealOp(OpTest):
56+
class TestAsRealOp(OpTest):
5757
def setUp(self):
5858
self.op_type = "as_real"
5959
real = np.random.randn(10, 10).astype("float64")
@@ -75,7 +75,7 @@ def test_check_grad(self):
7575
)
7676

7777

78-
class TestViewAsComplexAPI(unittest.TestCase):
78+
class TestAsComplexAPI(unittest.TestCase):
7979
def setUp(self):
8080
self.x = np.random.randn(10, 10, 2)
8181
self.out = ref_view_as_complex(self.x)
@@ -98,7 +98,7 @@ def test_static(self):
9898
np.testing.assert_allclose(self.out, out_np, rtol=1e-05)
9999

100100

101-
class TestViewAsRealAPI(unittest.TestCase):
101+
class TestAsRealAPI(unittest.TestCase):
102102
def setUp(self):
103103
self.x = np.random.randn(10, 10) + 1j * np.random.randn(10, 10)
104104
self.out = ref_view_as_real(self.x)
@@ -121,7 +121,7 @@ def test_static(self):
121121
np.testing.assert_allclose(self.out, out_np, rtol=1e-05)
122122

123123

124-
class TestViewAsRealAPI_ZeroSize(unittest.TestCase):
124+
class TestAsRealAPI_ZeroSize(unittest.TestCase):
125125
def setUp(self):
126126
self.x = np.random.randn(10, 0) + 1j * np.random.randn(10, 0)
127127
self.out = ref_view_as_real(self.x)
@@ -137,5 +137,50 @@ def test_dygraph(self):
137137
np.testing.assert_allclose(x_tensor.grad.shape, x_tensor.shape)
138138

139139

140+
class TestViewAsComplexAPI(unittest.TestCase):
141+
def setUp(self):
142+
self.x = np.random.randn(10, 10, 2)
143+
self.out = ref_view_as_complex(self.x)
144+
145+
def test_dygraph(self):
146+
with dygraph.guard():
147+
x = paddle.to_tensor(self.x)
148+
out = paddle.view_as_complex(x)
149+
out_np = out.numpy()
150+
self.assertEqual(out.data_ptr(), x.data_ptr())
151+
np.testing.assert_allclose(self.out, out_np, rtol=1e-05)
152+
153+
154+
class TestViewAsRealAPI(unittest.TestCase):
155+
def setUp(self):
156+
self.x = np.random.randn(10, 10) + 1j * np.random.randn(10, 10)
157+
self.out = ref_view_as_real(self.x)
158+
159+
def test_dygraph(self):
160+
with dygraph.guard():
161+
x = paddle.to_tensor(self.x)
162+
out = paddle.view_as_real(x)
163+
out_np = out.numpy()
164+
self.assertEqual(out.data_ptr(), x.data_ptr())
165+
np.testing.assert_allclose(self.out, out_np, rtol=1e-05)
166+
167+
168+
class TestViewAsRealAPI_ZeroSize(unittest.TestCase):
169+
def setUp(self):
170+
self.x = np.random.randn(10, 0) + 1j * np.random.randn(10, 0)
171+
self.out = ref_view_as_real(self.x)
172+
173+
def test_dygraph(self):
174+
for place in get_places():
175+
with dygraph.guard(place):
176+
x_tensor = paddle.to_tensor(self.x)
177+
x_tensor.stop_gradient = False
178+
out = paddle.view_as_real(x_tensor)
179+
np.testing.assert_allclose(self.out, out.numpy(), rtol=1e-05)
180+
self.assertEqual(out.data_ptr(), x_tensor.data_ptr())
181+
out.sum().backward()
182+
np.testing.assert_allclose(x_tensor.grad.shape, x_tensor.shape)
183+
184+
140185
if __name__ == "__main__":
141186
unittest.main()

0 commit comments

Comments
 (0)