Skip to content

Commit 17c568a

Browse files
view as real
1 parent 5b44422 commit 17c568a

File tree

3 files changed

+56
-5
lines changed

3 files changed

+56
-5
lines changed

python/paddle/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,9 @@
289289
)
290290
from .tensor.manipulation import (
291291
as_complex,
292+
as_complex as view_as_complex,
292293
as_real,
294+
as_real as view_as_real,
293295
as_strided,
294296
atleast_1d,
295297
atleast_2d,
@@ -1159,7 +1161,9 @@
11591161
'acosh',
11601162
'atanh',
11611163
'as_complex',
1164+
'view_as_complex',
11621165
'as_real',
1166+
'view_as_real',
11631167
'diff',
11641168
'angle',
11651169
'fmax',

python/paddle/tensor/manipulation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6188,6 +6188,7 @@ def _var_to_list(var):
61886188
return out
61896189

61906190

6191+
@ParamAliasDecorator({"x": ["input"]})
61916192
def as_complex(x: Tensor, name: str | None = None) -> Tensor:
61926193
"""Transform a real tensor to a complex tensor.
61936194
@@ -6241,6 +6242,7 @@ def as_complex(x: Tensor, name: str | None = None) -> Tensor:
62416242
return out
62426243

62436244

6245+
@ParamAliasDecorator({"x": ["input"]})
62446246
def as_real(x: Tensor, name: str | None = None) -> Tensor:
62456247
"""Transform a complex tensor to a real tensor.
62466248

test/legacy_test/test_complex_view_op.py

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def ref_view_as_real(x):
3333
return np.stack([x.real, x.imag], -1)
3434

3535

36-
class TestViewAsComplexOp(OpTest):
36+
class TestAsComplexOp(OpTest):
3737
def setUp(self):
3838
self.op_type = "as_complex"
3939
self.python_api = paddle.as_complex
@@ -53,7 +53,7 @@ def test_check_grad(self):
5353
)
5454

5555

56-
class TestViewAsRealOp(OpTest):
56+
class TestAsRealOp(OpTest):
5757
def setUp(self):
5858
self.op_type = "as_real"
5959
real = np.random.randn(10, 10).astype("float64")
@@ -75,7 +75,7 @@ def test_check_grad(self):
7575
)
7676

7777

78-
class TestViewAsComplexAPI(unittest.TestCase):
78+
class TestAsComplexAPI(unittest.TestCase):
7979
def setUp(self):
8080
self.x = np.random.randn(10, 10, 2)
8181
self.out = ref_view_as_complex(self.x)
@@ -98,7 +98,7 @@ def test_static(self):
9898
np.testing.assert_allclose(self.out, out_np, rtol=1e-05)
9999

100100

101-
class TestViewAsRealAPI(unittest.TestCase):
101+
class TestAsRealAPI(unittest.TestCase):
102102
def setUp(self):
103103
self.x = np.random.randn(10, 10) + 1j * np.random.randn(10, 10)
104104
self.out = ref_view_as_real(self.x)
@@ -121,7 +121,7 @@ def test_static(self):
121121
np.testing.assert_allclose(self.out, out_np, rtol=1e-05)
122122

123123

124-
class TestViewAsRealAPI_ZeroSize(unittest.TestCase):
124+
class TestAsRealAPI_ZeroSize(unittest.TestCase):
125125
def setUp(self):
126126
self.x = np.random.randn(10, 0) + 1j * np.random.randn(10, 0)
127127
self.out = ref_view_as_real(self.x)
@@ -137,5 +137,50 @@ def test_dygraph(self):
137137
np.testing.assert_allclose(x_tensor.grad.shape, x_tensor.shape)
138138

139139

140+
class TestViewAsComplexAPI(unittest.TestCase):
141+
def setUp(self):
142+
self.x = np.random.randn(10, 10, 2)
143+
self.out = ref_view_as_complex(self.x)
144+
145+
def test_dygraph(self):
146+
with dygraph.guard():
147+
x = paddle.to_tensor(self.x)
148+
out = paddle.view_as_complex(x)
149+
out_np = out.numpy()
150+
self.assertEqual(out.data_ptr(), x.data_ptr())
151+
np.testing.assert_allclose(self.out, out_np, rtol=1e-05)
152+
153+
154+
class TestViewAsRealAPI(unittest.TestCase):
155+
def setUp(self):
156+
self.x = np.random.randn(10, 10) + 1j * np.random.randn(10, 10)
157+
self.out = ref_view_as_real(self.x)
158+
159+
def test_dygraph(self):
160+
with dygraph.guard():
161+
x = paddle.to_tensor(self.x)
162+
out = paddle.view_as_real(x)
163+
out_np = out.numpy()
164+
self.assertEqual(out.data_ptr(), x.data_ptr())
165+
np.testing.assert_allclose(self.out, out_np, rtol=1e-05)
166+
167+
168+
class TestViewAsRealAPI_ZeroSize(unittest.TestCase):
169+
def setUp(self):
170+
self.x = np.random.randn(10, 0) + 1j * np.random.randn(10, 0)
171+
self.out = ref_view_as_real(self.x)
172+
173+
def test_dygraph(self):
174+
for place in get_places():
175+
with dygraph.guard(place):
176+
x_tensor = paddle.to_tensor(self.x)
177+
x_tensor.stop_gradient = False
178+
out = paddle.view_as_real(x_tensor)
179+
np.testing.assert_allclose(self.out, out.numpy(), rtol=1e-05)
180+
self.assertEqual(out.data_ptr(), x_tensor.data_ptr())
181+
out.sum().backward()
182+
np.testing.assert_allclose(x_tensor.grad.shape, x_tensor.shape)
183+
184+
140185
if __name__ == "__main__":
141186
unittest.main()

0 commit comments

Comments
 (0)