1
1
import io
2
2
from contextlib import asynccontextmanager
3
- from pathlib import Path
3
+ import logging
4
4
from typing import Annotated , Any , Literal , Sequence , cast
5
5
from urllib .parse import urlparse
6
6
from uuid import UUID
24
24
from PIL import Image
25
25
from pydantic import BaseModel , Field
26
26
from replicate .client import Client # type: ignore
27
+ from sqlalchemy .exc import IntegrityError
27
28
from sqlalchemy .orm import joinedload
28
29
from sqlmodel import Session , select
29
30
35
36
settings = get_settings ()
36
37
replicate = Client (api_token = settings .REPLICATE_API_TOKEN )
37
38
39
+ logging .basicConfig (
40
+ level = logging .INFO ,
41
+ format = "%(asctime)s.%(msecs)03d %(levelname)s %(message)s" ,
42
+ datefmt = "%Y-%m-%d %H:%M:%S" ,
43
+ )
44
+
38
45
39
46
def get_session ():
40
47
with Session (db .engine ) as session :
@@ -79,29 +86,69 @@ def get_current_user(
79
86
).one_or_none ()
80
87
81
88
if current_user is None :
82
- # Add avatar image
83
- # TODO: get avatar during onboarding flow (there is currently no way for new users to upload an avatar )
84
- ROOT_PATH = Path ( __file__ ). parent . parent . parent . parent
85
- image_path = ROOT_PATH / Path ( "images/humans/model.jpg" )
86
- with open ( image_path , "rb" ) as image_file :
87
- avatar_image = db . AvatarImage ( image_data = image_file . read () )
88
- session . add ( avatar_image )
89
-
90
- # Add user
91
- print ( f"Creating new user with auth0_user_id: { repr (auth0_user_id )} ")
92
- current_user = db . User (
93
- auth0_user_id = auth0_user_id , avatar_image_id = avatar_image . id
94
- )
95
- session . add ( current_user )
96
- session . commit ()
89
+ try :
90
+ logging . info ( f"Creating new user with auth0_user_id: { repr ( auth0_user_id ) } " )
91
+ current_user = db . User ( auth0_user_id = auth0_user_id )
92
+ session . add ( current_user )
93
+ session . commit ()
94
+ session . refresh ( current_user )
95
+ except IntegrityError :
96
+ # Handle race condition: another request created the user concurrentl
97
+ logging . error (
98
+ f"User creation failed due to integrity error (likely race condition) for auth0_user_id: { repr (auth0_user_id )} "
99
+ )
100
+ session . rollback ()
101
+ current_user = session . exec (
102
+ select ( db . User ). where ( db . User . auth0_user_id == auth0_user_id )
103
+ ). one ()
97
104
98
105
return current_user
99
106
100
107
101
108
class User (BaseModel ):
102
109
id : UUID
103
- auth0_user_id : str
104
- avatar_image_url : str
110
+ has_avatar_image : bool
111
+
112
+
113
+ @app .get ("/users/me" )
114
+ def get_me (
115
+ * ,
116
+ current_user : db .User = Depends (get_current_user ),
117
+ ) -> User :
118
+ return User (
119
+ id = current_user .id ,
120
+ has_avatar_image = current_user .avatar_image is not None ,
121
+ )
122
+
123
+
124
+ @app .put ("/images/avatars/me" )
125
+ def update_avatar_image (
126
+ * ,
127
+ image : UploadFile ,
128
+ session : Session = Depends (get_session ),
129
+ current_user : db .User = Depends (get_current_user ),
130
+ ):
131
+ # Check if the user already has an avatar image
132
+ if current_user .avatar_image is not None :
133
+ raise HTTPException (
134
+ status_code = status .HTTP_400_BAD_REQUEST ,
135
+ detail = "It's currently not possible to replace an existing avatar image." ,
136
+ )
137
+
138
+ # Convert the image to JPG and compress
139
+ img = Image .open (image .file )
140
+ compressed_img_buf = io .BytesIO ()
141
+ img .convert ("RGB" ).save (compressed_img_buf , format = "JPEG" , quality = 75 )
142
+
143
+ # Create a new avatar image with the compressed data
144
+ new_avatar_image = db .AvatarImage (image_data = compressed_img_buf .getvalue ())
145
+ session .add (new_avatar_image )
146
+
147
+ # Update the user's avatar image
148
+ current_user .avatar_image = new_avatar_image
149
+ session .commit ()
150
+
151
+ return Response (status_code = status .HTTP_200_OK )
105
152
106
153
107
154
class Wearable (BaseModel ):
@@ -201,7 +248,7 @@ def create_woa_image(*, wearable_id: UUID, user_id: UUID):
201
248
"""
202
249
203
250
with Session (db .engine ) as session :
204
- print ("Starting WOA generation" )
251
+ logging . info ("Starting WOA generation" )
205
252
user = session .exec (select (db .User ).where (db .User .id == user_id )).one ()
206
253
wearable = session .exec (
207
254
select (db .Wearable )
@@ -210,7 +257,7 @@ def create_woa_image(*, wearable_id: UUID, user_id: UUID):
210
257
).one ()
211
258
212
259
# Generate an image of the avatar wearing the wearable
213
- print ("Generating WOA image" )
260
+ logging . info ("Generating WOA image" )
214
261
assert wearable .wearable_image is not None
215
262
assert user .avatar_image is not None
216
263
woa_image_url_raw = replicate .run (
@@ -225,7 +272,7 @@ def create_woa_image(*, wearable_id: UUID, user_id: UUID):
225
272
woa_image_url = urlparse (str (woa_image_url_raw )).geturl ()
226
273
227
274
# Get a mask of the wearable on the avatar using an image segmentation model
228
- print ("Generating mask" )
275
+ logging . info ("Generating mask" )
229
276
mask_results = replicate .run (
230
277
"schananas/grounded_sam:ee871c19efb1941f55f66a3d7d960428c8a5afcb77449547fe8e5a3ab9ebc21c" ,
231
278
input = {
@@ -247,14 +294,14 @@ def create_woa_image(*, wearable_id: UUID, user_id: UUID):
247
294
if mask_image_url is None :
248
295
raise ValueError ("Could not get mask URL" )
249
296
250
- print ("Fetching results" )
297
+ logging . info ("Fetching results" )
251
298
woa_image_response = requests .get (woa_image_url , stream = True )
252
299
woa_image_response .raise_for_status ()
253
300
254
301
mask_image_response = requests .get (mask_image_url , stream = True )
255
302
mask_image_response .raise_for_status ()
256
303
257
- print ("Saving results to DB" )
304
+ logging . info ("Saving results to DB" )
258
305
assert user .avatar_image is not None
259
306
assert wearable .wearable_image is not None
260
307
woa_image = db .WearableOnAvatarImage (
@@ -265,7 +312,7 @@ def create_woa_image(*, wearable_id: UUID, user_id: UUID):
265
312
)
266
313
session .add (woa_image )
267
314
session .commit ()
268
- print ("Finished generating WOA image" )
315
+ logging . info ("Finished generating WOA image" )
269
316
270
317
271
318
@app .post ("/wearables" , status_code = status .HTTP_201_CREATED )
@@ -298,7 +345,12 @@ def create_wearables(
298
345
for item_category , item_description , item_image in zip (
299
346
category , description , image , strict = True
300
347
):
301
- wearable_image = db .WearableImage (image_data = item_image .file .read ())
348
+ # Convert the image to JPG and compress
349
+ img = Image .open (item_image .file )
350
+ compressed_img_buf = io .BytesIO ()
351
+ img .convert ("RGB" ).save (compressed_img_buf , format = "JPEG" , quality = 75 )
352
+
353
+ wearable_image = db .WearableImage (image_data = compressed_img_buf .getvalue ())
302
354
session .add (wearable_image )
303
355
304
356
wearable = db .Wearable (
0 commit comments