Skip to content

Commit ffeb8a0

Browse files
authored
Merge pull request #8 from nimobeeren/upload-avatar-image
Upload avatar image
2 parents f2bfbcb + b6e36a4 commit ffeb8a0

24 files changed

+1240
-369
lines changed

api/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@ This will start a development server on `http://localhost:8000`.
1818

1919
### Adding test data
2020

21-
When you first start the backend, the database will be empty. To add some test data, you should set the `AUTH0_SEED_USER_ID` environment variable, then run the seed script:
21+
When you first start the backend, the database will be empty. To add some test data, you should set the `AUTH0_SEED_USER_ID` environment variable (in your `.env` file or your shell), then run the seed script:
2222

2323
```bash
24-
uv run -m dressme.db.seed
24+
uv run seed
2525
```
2626

2727
This will add some wearables to the database stored in the `dressme.db` file.

api/pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,11 @@ dev = [
2828
"pyright>=1.1.399",
2929
]
3030

31+
[project.scripts]
32+
seed = "dressme.db.seed:seed"
33+
3134
[tool.uv]
32-
package = false
35+
package = true
3336

3437
[build-system]
3538
requires = ["hatchling"]

api/src/dressme/db/models.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515

1616
class User(SQLModel, table=True):
1717
id: UUID = Field(default_factory=uuid4, primary_key=True)
18-
auth0_user_id: str = Field(index=True)
19-
avatar_image_id: UUID = Field(foreign_key="avatarimage.id", index=True)
18+
auth0_user_id: str = Field(index=True, unique=True)
19+
avatar_image_id: Optional[UUID] = Field(
20+
foreign_key="avatarimage.id", index=True, default=None
21+
)
2022
avatar_image: Optional["AvatarImage"] = Relationship()
2123
outfits: list["Outfit"] = Relationship(back_populates="user")
2224
wearables: list["Wearable"] = Relationship(back_populates="user")

api/src/dressme/db/seed.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
}
6363

6464
# Path to the repo root
65-
ROOT_PATH = Path(__file__).parent.parent.parent.parent
65+
ROOT_PATH = Path(__file__).parent.parent.parent.parent.parent
6666

6767

6868
def seed():

api/src/dressme/main.py

Lines changed: 77 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import io
22
from contextlib import asynccontextmanager
3-
from pathlib import Path
3+
import logging
44
from typing import Annotated, Any, Literal, Sequence, cast
55
from urllib.parse import urlparse
66
from uuid import UUID
@@ -24,6 +24,7 @@
2424
from PIL import Image
2525
from pydantic import BaseModel, Field
2626
from replicate.client import Client # type: ignore
27+
from sqlalchemy.exc import IntegrityError
2728
from sqlalchemy.orm import joinedload
2829
from sqlmodel import Session, select
2930

@@ -35,6 +36,12 @@
3536
settings = get_settings()
3637
replicate = Client(api_token=settings.REPLICATE_API_TOKEN)
3738

39+
logging.basicConfig(
40+
level=logging.INFO,
41+
format="%(asctime)s.%(msecs)03d %(levelname)s %(message)s",
42+
datefmt="%Y-%m-%d %H:%M:%S",
43+
)
44+
3845

3946
def get_session():
4047
with Session(db.engine) as session:
@@ -79,29 +86,69 @@ def get_current_user(
7986
).one_or_none()
8087

8188
if current_user is None:
82-
# Add avatar image
83-
# TODO: get avatar during onboarding flow (there is currently no way for new users to upload an avatar)
84-
ROOT_PATH = Path(__file__).parent.parent.parent.parent
85-
image_path = ROOT_PATH / Path("images/humans/model.jpg")
86-
with open(image_path, "rb") as image_file:
87-
avatar_image = db.AvatarImage(image_data=image_file.read())
88-
session.add(avatar_image)
89-
90-
# Add user
91-
print(f"Creating new user with auth0_user_id: {repr(auth0_user_id)}")
92-
current_user = db.User(
93-
auth0_user_id=auth0_user_id, avatar_image_id=avatar_image.id
94-
)
95-
session.add(current_user)
96-
session.commit()
89+
try:
90+
logging.info(f"Creating new user with auth0_user_id: {repr(auth0_user_id)}")
91+
current_user = db.User(auth0_user_id=auth0_user_id)
92+
session.add(current_user)
93+
session.commit()
94+
session.refresh(current_user)
95+
except IntegrityError:
96+
# Handle race condition: another request created the user concurrentl
97+
logging.error(
98+
f"User creation failed due to integrity error (likely race condition) for auth0_user_id: {repr(auth0_user_id)}"
99+
)
100+
session.rollback()
101+
current_user = session.exec(
102+
select(db.User).where(db.User.auth0_user_id == auth0_user_id)
103+
).one()
97104

98105
return current_user
99106

100107

101108
class User(BaseModel):
102109
id: UUID
103-
auth0_user_id: str
104-
avatar_image_url: str
110+
has_avatar_image: bool
111+
112+
113+
@app.get("/users/me")
114+
def get_me(
115+
*,
116+
current_user: db.User = Depends(get_current_user),
117+
) -> User:
118+
return User(
119+
id=current_user.id,
120+
has_avatar_image=current_user.avatar_image is not None,
121+
)
122+
123+
124+
@app.put("/images/avatars/me")
125+
def update_avatar_image(
126+
*,
127+
image: UploadFile,
128+
session: Session = Depends(get_session),
129+
current_user: db.User = Depends(get_current_user),
130+
):
131+
# Check if the user already has an avatar image
132+
if current_user.avatar_image is not None:
133+
raise HTTPException(
134+
status_code=status.HTTP_400_BAD_REQUEST,
135+
detail="It's currently not possible to replace an existing avatar image.",
136+
)
137+
138+
# Convert the image to JPG and compress
139+
img = Image.open(image.file)
140+
compressed_img_buf = io.BytesIO()
141+
img.convert("RGB").save(compressed_img_buf, format="JPEG", quality=75)
142+
143+
# Create a new avatar image with the compressed data
144+
new_avatar_image = db.AvatarImage(image_data=compressed_img_buf.getvalue())
145+
session.add(new_avatar_image)
146+
147+
# Update the user's avatar image
148+
current_user.avatar_image = new_avatar_image
149+
session.commit()
150+
151+
return Response(status_code=status.HTTP_200_OK)
105152

106153

107154
class Wearable(BaseModel):
@@ -201,7 +248,7 @@ def create_woa_image(*, wearable_id: UUID, user_id: UUID):
201248
"""
202249

203250
with Session(db.engine) as session:
204-
print("Starting WOA generation")
251+
logging.info("Starting WOA generation")
205252
user = session.exec(select(db.User).where(db.User.id == user_id)).one()
206253
wearable = session.exec(
207254
select(db.Wearable)
@@ -210,7 +257,7 @@ def create_woa_image(*, wearable_id: UUID, user_id: UUID):
210257
).one()
211258

212259
# Generate an image of the avatar wearing the wearable
213-
print("Generating WOA image")
260+
logging.info("Generating WOA image")
214261
assert wearable.wearable_image is not None
215262
assert user.avatar_image is not None
216263
woa_image_url_raw = replicate.run(
@@ -225,7 +272,7 @@ def create_woa_image(*, wearable_id: UUID, user_id: UUID):
225272
woa_image_url = urlparse(str(woa_image_url_raw)).geturl()
226273

227274
# Get a mask of the wearable on the avatar using an image segmentation model
228-
print("Generating mask")
275+
logging.info("Generating mask")
229276
mask_results = replicate.run(
230277
"schananas/grounded_sam:ee871c19efb1941f55f66a3d7d960428c8a5afcb77449547fe8e5a3ab9ebc21c",
231278
input={
@@ -247,14 +294,14 @@ def create_woa_image(*, wearable_id: UUID, user_id: UUID):
247294
if mask_image_url is None:
248295
raise ValueError("Could not get mask URL")
249296

250-
print("Fetching results")
297+
logging.info("Fetching results")
251298
woa_image_response = requests.get(woa_image_url, stream=True)
252299
woa_image_response.raise_for_status()
253300

254301
mask_image_response = requests.get(mask_image_url, stream=True)
255302
mask_image_response.raise_for_status()
256303

257-
print("Saving results to DB")
304+
logging.info("Saving results to DB")
258305
assert user.avatar_image is not None
259306
assert wearable.wearable_image is not None
260307
woa_image = db.WearableOnAvatarImage(
@@ -265,7 +312,7 @@ def create_woa_image(*, wearable_id: UUID, user_id: UUID):
265312
)
266313
session.add(woa_image)
267314
session.commit()
268-
print("Finished generating WOA image")
315+
logging.info("Finished generating WOA image")
269316

270317

271318
@app.post("/wearables", status_code=status.HTTP_201_CREATED)
@@ -298,7 +345,12 @@ def create_wearables(
298345
for item_category, item_description, item_image in zip(
299346
category, description, image, strict=True
300347
):
301-
wearable_image = db.WearableImage(image_data=item_image.file.read())
348+
# Convert the image to JPG and compress
349+
img = Image.open(item_image.file)
350+
compressed_img_buf = io.BytesIO()
351+
img.convert("RGB").save(compressed_img_buf, format="JPEG", quality=75)
352+
353+
wearable_image = db.WearableImage(image_data=compressed_img_buf.getvalue())
302354
session.add(wearable_image)
303355

304356
wearable = db.Wearable(

0 commit comments

Comments
 (0)