1
+ import json
1
2
import logging
3
+ import os
2
4
import random
3
5
import re
4
6
import socket
5
7
import string
6
8
import subprocess
9
+ from pathlib import Path
7
10
8
11
import numpy as np
12
+ import psutil
9
13
import requests
14
+ from filelock import FileLock
10
15
from libcloud .compute .providers import get_driver
11
16
from libcloud .compute .types import Provider
12
17
@@ -81,7 +86,7 @@ def still_exists(self):
81
86
82
87
@property
def free(self):
    """Whether this TPU can be claimed.

    Returns False as soon as this object's own ``_in_use`` flag is set;
    otherwise the answer comes from the manager's lockfile, which is only
    consulted in that second case.
    """
    if self._in_use:
        return False
    claimed_in_registry = self.manager.lockfile.check_if_in_use(self)
    return not claimed_in_registry
85
90
86
91
@property
87
92
def usable (self ):
@@ -131,9 +136,73 @@ def delete(self, background=True):
131
136
132
137
def in_use(self):
    """Claim this TPU for the current process.

    The claim is written to the manager's lockfile first and the local
    ``_in_use`` flag is set only afterwards. If registration raises, the
    flag is left untouched, so a later ``release()`` cannot pass its
    ownership check and remove a registry entry this object never created.

    Raises:
        Exception: propagated unchanged from ``lockfile.register_in_use``.
    """
    self.manager.lockfile.register_in_use(self)
    self._in_use = True
134
140
135
141
def release(self):
    """Give up this TPU: clear ``_in_use`` and drop it from the lockfile.

    Raises:
        AssertionError: if this object is not currently marked in use.
    """
    # Raised explicitly instead of using a bare ``assert`` so the check is
    # still enforced under ``python -O``; the exception type is unchanged.
    if not self._in_use:
        raise AssertionError("release() called on a TPU that is not in use")
    self._in_use = False
    self.manager.lockfile.register_free(self)
145
+
146
+
147
class TPULockFile:
    """File-backed registry of which process has claimed which TPU.

    The registry is a JSON object stored at ``filepath`` that maps a TPU
    name to the PID that registered it. Every public method performs its
    read-modify-write cycle while holding a ``FileLock`` on
    ``<filepath>.lock``.
    """

    def __init__(self, filepath):
        """Open the registry, creating the data and lock files if missing.

        Args:
            filepath: Location of the registry file, as a string or
                path-like object; a leading ``~`` is expanded.
        """
        # os.fspath lets callers pass a Path as well as a plain string,
        # which the string concatenation below would otherwise reject.
        filepath = os.fspath(filepath)
        self.filepath = Path(filepath).expanduser()
        self.lockpath = Path(filepath + ".lock").expanduser()
        self.filelock = FileLock(self.lockpath)

        for path in (self.filepath, self.lockpath):
            if not path.exists():
                path.touch()

    def _read_registry(self):
        """Return the registry as a dict; an empty file yields ``{}``."""
        with open(self.filepath, "r") as f:
            raw = f.read()
        return json.loads(raw) if raw else {}

    def _write_registry(self, registry):
        """Overwrite the registry file with ``registry`` serialized as JSON."""
        with open(self.filepath, "w") as f:
            f.write(json.dumps(registry))

    def register_free(self, tpu):
        """Remove ``tpu.name`` from the registry; no-op if it is absent."""
        with self.filelock:
            registry = self._read_registry()
            if tpu.name not in registry:
                return
            del registry[tpu.name]
            self._write_registry(registry)

    def register_in_use(self, tpu):
        """Record the current process as the holder of ``tpu.name``.

        An entry already owned by this PID is simply rewritten. An entry
        whose PID no longer exists is taken over with a warning.

        Raises:
            Exception: if the entry belongs to a different PID that
                ``psutil.pid_exists`` reports as still existing.
        """
        current_pid = os.getpid()
        with self.filelock:
            registry = self._read_registry()
            if tpu.name in registry:
                owner = registry[tpu.name]
                if owner != current_pid:
                    if psutil.pid_exists(owner):
                        raise Exception("TPU is already registered")
                    logger.warning(
                        "Forcefully acquiring TPU %s from dead pid %s.",
                        tpu.name,
                        owner,
                    )
            registry[tpu.name] = current_pid
            self._write_registry(registry)

    def check_if_in_use(self, tpu):
        """Return True if ``tpu.name`` is registered to an existing PID.

        An entry whose PID no longer exists is removed from the registry
        (with a warning) and reported as not in use.
        """
        with self.filelock:
            registry = self._read_registry()
            if tpu.name not in registry:
                return False

            owner = registry[tpu.name]
            if psutil.pid_exists(owner):
                return True

            logger.warning(
                "Removing TPU %s from dead pid %s.", tpu.name, owner
            )
            del registry[tpu.name]
            self._write_registry(registry)

        return False
137
206
138
207
139
208
class TPUManager (env .ResourceManager ):
@@ -156,6 +225,8 @@ def __init__(self, instance):
156
225
lines = r .split ("\n " )[1 :]
157
226
lines = list (filter (lambda l : l != "" , lines ))
158
227
self .zone = lines [0 ].split ()[1 ]
228
+ from cloud import socket_path
229
+ self .lockfile = TPULockFile (os .path .join ("~" , ".tpu_registry" ))
159
230
self .refresh ()
160
231
161
232
@property
0 commit comments