@@ -104,26 +104,21 @@ def test_invalid_tensor_dict_raises_error():
104
104
def test_save_file_and_load_file_consistency ():
105
105
tensor_dict = create_gpt2_numpy_dict (1 )
106
106
filename = ""
107
- with tempfile .NamedTemporaryFile (delete = False ) as tmp :
107
+ loaded_dict = {}
108
+ with tempfile .NamedTemporaryFile () as tmp :
108
109
filename = tmp .name
109
-
110
- try :
111
- save_file (tensor_dict , filename )
112
- loaded_dict = load_file (filename )
113
-
114
- for key , value in tensor_dict .items ():
115
- assert _compare_np_array (loaded_dict [key ], value )
116
- finally :
117
- if os .path .exists (filename ):
118
- os .remove (filename )
110
+ try :
111
+ save_file (tensor_dict , filename )
112
+ loaded_dict = load_file (filename )
113
+ finally :
114
+ for key , value in tensor_dict .items ():
115
+ assert _compare_np_array (loaded_dict [key ], value )
119
116
120
117
121
118
def test_safe_open_access_and_metadata ():
122
119
tensor_dict = create_gpt2_numpy_dict (1 )
123
- with tempfile .NamedTemporaryFile (delete = False ) as tmp :
120
+ with tempfile .NamedTemporaryFile () as tmp :
124
121
filename = tmp .name
125
-
126
- try :
127
122
# save file into tempfile
128
123
save_file (tensor_dict , filename )
129
124
@@ -132,22 +127,15 @@ def test_safe_open_access_and_metadata():
132
127
assert model .get_tensor ("h.0.ln_1.weight" ) is not None
133
128
assert model .get_tensor ("h.0.ln_1.bias" ) is not None
134
129
assert model .metadata () is None
135
- finally :
136
- if os .path .exists (filename ):
137
- os .remove (filename )
138
130
139
131
140
132
def test_safe_open_access_with_metadata ():
141
133
tensor_dict = create_gpt2_numpy_dict (1 )
142
- with tempfile .NamedTemporaryFile (delete = False ) as tmp :
134
+ with tempfile .NamedTemporaryFile () as tmp :
143
135
filename = tmp .name
144
136
145
- try :
146
137
save_file (tensor_dict , filename , metadata = {"hello" : "world" })
147
138
with safe_open (filename , "numpy" ) as model :
148
139
assert model .get_tensor ("h.0.ln_1.weight" ) is not None
149
140
assert model .get_tensor ("h.0.ln_1.bias" ) is not None
150
141
assert model .metadata ()["hello" ] == "world"
151
- finally :
152
- if os .path .exists (filename ):
153
- os .remove (filename )
0 commit comments