@@ -33,7 +33,7 @@ def ref_view_as_real(x):
33
33
return np .stack ([x .real , x .imag ], - 1 )
34
34
35
35
36
- class TestViewAsComplexOp (OpTest ):
36
+ class TestAsComplexOp (OpTest ):
37
37
def setUp (self ):
38
38
self .op_type = "as_complex"
39
39
self .python_api = paddle .as_complex
@@ -53,7 +53,7 @@ def test_check_grad(self):
53
53
)
54
54
55
55
56
- class TestViewAsRealOp (OpTest ):
56
+ class TestAsRealOp (OpTest ):
57
57
def setUp (self ):
58
58
self .op_type = "as_real"
59
59
real = np .random .randn (10 , 10 ).astype ("float64" )
@@ -75,7 +75,7 @@ def test_check_grad(self):
75
75
)
76
76
77
77
78
- class TestViewAsComplexAPI (unittest .TestCase ):
78
+ class TestAsComplexAPI (unittest .TestCase ):
79
79
def setUp (self ):
80
80
self .x = np .random .randn (10 , 10 , 2 )
81
81
self .out = ref_view_as_complex (self .x )
@@ -98,7 +98,7 @@ def test_static(self):
98
98
np .testing .assert_allclose (self .out , out_np , rtol = 1e-05 )
99
99
100
100
101
- class TestViewAsRealAPI (unittest .TestCase ):
101
+ class TestAsRealAPI (unittest .TestCase ):
102
102
def setUp (self ):
103
103
self .x = np .random .randn (10 , 10 ) + 1j * np .random .randn (10 , 10 )
104
104
self .out = ref_view_as_real (self .x )
@@ -121,7 +121,7 @@ def test_static(self):
121
121
np .testing .assert_allclose (self .out , out_np , rtol = 1e-05 )
122
122
123
123
124
- class TestViewAsRealAPI_ZeroSize (unittest .TestCase ):
124
+ class TestAsRealAPI_ZeroSize (unittest .TestCase ):
125
125
def setUp (self ):
126
126
self .x = np .random .randn (10 , 0 ) + 1j * np .random .randn (10 , 0 )
127
127
self .out = ref_view_as_real (self .x )
@@ -137,5 +137,50 @@ def test_dygraph(self):
137
137
np .testing .assert_allclose (x_tensor .grad .shape , x_tensor .shape )
138
138
139
139
140
+ class TestViewAsComplexAPI (unittest .TestCase ):
141
+ def setUp (self ):
142
+ self .x = np .random .randn (10 , 10 , 2 )
143
+ self .out = ref_view_as_complex (self .x )
144
+
145
+ def test_dygraph (self ):
146
+ with dygraph .guard ():
147
+ x = paddle .to_tensor (self .x )
148
+ out = paddle .view_as_complex (x )
149
+ out_np = out .numpy ()
150
+ self .assertEqual (out .data_ptr (), x .data_ptr ())
151
+ np .testing .assert_allclose (self .out , out_np , rtol = 1e-05 )
152
+
153
+
154
+ class TestViewAsRealAPI (unittest .TestCase ):
155
+ def setUp (self ):
156
+ self .x = np .random .randn (10 , 10 ) + 1j * np .random .randn (10 , 10 )
157
+ self .out = ref_view_as_real (self .x )
158
+
159
+ def test_dygraph (self ):
160
+ with dygraph .guard ():
161
+ x = paddle .to_tensor (self .x )
162
+ out = paddle .view_as_real (x )
163
+ out_np = out .numpy ()
164
+ self .assertEqual (out .data_ptr (), x .data_ptr ())
165
+ np .testing .assert_allclose (self .out , out_np , rtol = 1e-05 )
166
+
167
+
168
+ class TestViewAsRealAPI_ZeroSize (unittest .TestCase ):
169
+ def setUp (self ):
170
+ self .x = np .random .randn (10 , 0 ) + 1j * np .random .randn (10 , 0 )
171
+ self .out = ref_view_as_real (self .x )
172
+
173
+ def test_dygraph (self ):
174
+ for place in get_places ():
175
+ with dygraph .guard (place ):
176
+ x_tensor = paddle .to_tensor (self .x )
177
+ x_tensor .stop_gradient = False
178
+ out = paddle .view_as_real (x_tensor )
179
+ np .testing .assert_allclose (self .out , out .numpy (), rtol = 1e-05 )
180
+ self .assertEqual (out .data_ptr (), x_tensor .data_ptr ())
181
+ out .sum ().backward ()
182
+ np .testing .assert_allclose (x_tensor .grad .shape , x_tensor .shape )
183
+
184
+
140
185
if __name__ == "__main__" :
141
186
unittest .main ()
0 commit comments