Skip to content

Commit 9770619

Browse files
Add view_as_complex and view_as_real
1 parent 5b44422 commit 9770619

File tree

3 files changed

+56
-7
lines changed

3 files changed

+56
-7
lines changed

python/paddle/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,8 @@
363363
unstack,
364364
view,
365365
view_as,
366+
as_complex as view_as_complex,
367+
as_real as view_as_real,
366368
vsplit,
367369
vstack,
368370
)
@@ -1159,7 +1161,9 @@
11591161
'acosh',
11601162
'atanh',
11611163
'as_complex',
1164+
'view_as_complex',
11621165
'as_real',
1166+
'view_as_real',
11631167
'diff',
11641168
'angle',
11651169
'fmax',

python/paddle/tensor/manipulation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6187,7 +6187,7 @@ def _var_to_list(var):
61876187
out = x.matmul(y).reshape(shape_out)
61886188
return out
61896189

6190-
6190+
@ParamAliasDecorator({"x": ["input"]})
61916191
def as_complex(x: Tensor, name: str | None = None) -> Tensor:
61926192
"""Transform a real tensor to a complex tensor.
61936193
@@ -6240,7 +6240,7 @@ def as_complex(x: Tensor, name: str | None = None) -> Tensor:
62406240
)
62416241
return out
62426242

6243-
6243+
@ParamAliasDecorator({"x": ["input"]})
62446244
def as_real(x: Tensor, name: str | None = None) -> Tensor:
62456245
"""Transform a complex tensor to a real tensor.
62466246

test/legacy_test/test_complex_view_op.py

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def ref_view_as_real(x):
3333
return np.stack([x.real, x.imag], -1)
3434

3535

36-
class TestViewAsComplexOp(OpTest):
36+
class TestAsComplexOp(OpTest):
3737
def setUp(self):
3838
self.op_type = "as_complex"
3939
self.python_api = paddle.as_complex
@@ -53,7 +53,7 @@ def test_check_grad(self):
5353
)
5454

5555

56-
class TestViewAsRealOp(OpTest):
56+
class TestAsRealOp(OpTest):
5757
def setUp(self):
5858
self.op_type = "as_real"
5959
real = np.random.randn(10, 10).astype("float64")
@@ -75,7 +75,7 @@ def test_check_grad(self):
7575
)
7676

7777

78-
class TestViewAsComplexAPI(unittest.TestCase):
78+
class TestAsComplexAPI(unittest.TestCase):
7979
def setUp(self):
8080
self.x = np.random.randn(10, 10, 2)
8181
self.out = ref_view_as_complex(self.x)
@@ -98,7 +98,7 @@ def test_static(self):
9898
np.testing.assert_allclose(self.out, out_np, rtol=1e-05)
9999

100100

101-
class TestViewAsRealAPI(unittest.TestCase):
101+
class TestAsRealAPI(unittest.TestCase):
102102
def setUp(self):
103103
self.x = np.random.randn(10, 10) + 1j * np.random.randn(10, 10)
104104
self.out = ref_view_as_real(self.x)
@@ -121,7 +121,7 @@ def test_static(self):
121121
np.testing.assert_allclose(self.out, out_np, rtol=1e-05)
122122

123123

124-
class TestViewAsRealAPI_ZeroSize(unittest.TestCase):
124+
class TestAsRealAPI_ZeroSize(unittest.TestCase):
125125
def setUp(self):
126126
self.x = np.random.randn(10, 0) + 1j * np.random.randn(10, 0)
127127
self.out = ref_view_as_real(self.x)
@@ -137,5 +137,50 @@ def test_dygraph(self):
137137
np.testing.assert_allclose(x_tensor.grad.shape, x_tensor.shape)
138138

139139

140+
class TestViewAsComplexAPI(unittest.TestCase):
141+
def setUp(self):
142+
self.x = np.random.randn(10, 10, 2)
143+
self.out = ref_view_as_complex(self.x)
144+
145+
def test_dygraph(self):
146+
with dygraph.guard():
147+
x = paddle.to_tensor(self.x)
148+
out = paddle.view_as_complex(x)
149+
out_np = out.numpy()
150+
self.assertEqual(out.data_ptr(), x.data_ptr())
151+
np.testing.assert_allclose(self.out, out_np, rtol=1e-05)
152+
153+
154+
class TestViewAsRealAPI(unittest.TestCase):
155+
def setUp(self):
156+
self.x = np.random.randn(10, 10) + 1j * np.random.randn(10, 10)
157+
self.out = ref_view_as_real(self.x)
158+
159+
def test_dygraph(self):
160+
with dygraph.guard():
161+
x = paddle.to_tensor(self.x)
162+
out = paddle.view_as_real(x)
163+
out_np = out.numpy()
164+
self.assertEqual(out.data_ptr(), x.data_ptr())
165+
np.testing.assert_allclose(self.out, out_np, rtol=1e-05)
166+
167+
168+
class TestViewAsRealAPI_ZeroSize(unittest.TestCase):
169+
def setUp(self):
170+
self.x = np.random.randn(10, 0) + 1j * np.random.randn(10, 0)
171+
self.out = ref_view_as_real(self.x)
172+
173+
def test_dygraph(self):
174+
for place in get_places():
175+
with dygraph.guard(place):
176+
x_tensor = paddle.to_tensor(self.x)
177+
x_tensor.stop_gradient = False
178+
out = paddle.view_as_real(x_tensor)
179+
np.testing.assert_allclose(self.out, out.numpy(), rtol=1e-05)
180+
self.assertEqual(out.data_ptr(), x_tensor.data_ptr())
181+
out.sum().backward()
182+
np.testing.assert_allclose(x_tensor.grad.shape, x_tensor.shape)
183+
184+
140185
if __name__ == "__main__":
141186
unittest.main()

0 commit comments

Comments
 (0)