Skip to content

Commit 64e141f

Browse files
fix(fixes-metadata-cleanup): fixes metadata cleanup problem also introduces an ability to keep all runtime metadata
1 parent e370b77 commit 64e141f

File tree

1 file changed

+46
-37
lines changed

1 file changed

+46
-37
lines changed

predictor/app.py

Lines changed: 46 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -26,41 +26,48 @@ async def predict(
2626
geojson=None,
2727
merge_input_images_to_single_image=False,
2828
get_predictions_as_points=True,
29-
ortho_skew_tolerance_deg: int = 15, # angle (0,45> degrees
30-
ortho_max_angle_change_deg: int = 15, # angle (0,45> degrees
29+
ortho_skew_tolerance_deg=15,
30+
ortho_max_angle_change_deg=15,
3131
):
32-
"""
33-
Parameters:
34-
bbox : Bounding box of the area you want to run prediction on
35-
model_path : Path of your downloaded model checkpoint
36-
zoom_level : Zoom level of the tiles to be used for prediction
37-
tms_url : Your Image URL on which you want to detect feature
38-
tile_size : Optional >> Tile size to be used in pixel default : 256*256
39-
base_path : Optional >> Basepath for your working dir of prediction
40-
confidence: Optional >> Threshold probability for filtering out low-confidence predictions, Defaults to 0.5
41-
area_threshold (float, optional): Threshold for filtering polygon areas. Defaults to 3 sqm.
42-
tolerance (float, optional): Tolerance parameter for simplifying polygons. Defaults to 0.5 m. Percentage Tolerance = (Tolerance in Meters / Arc Length in Meters ​)×100
32+
"""Detect buildings using ML model and return as GeoJSON.
4333
34+
Parameters:
35+
model_path: Path of downloaded model checkpoint
36+
zoom_level: Zoom level for prediction tiles
37+
tms_url: Image URL for feature detection
38+
output_path: Directory to save prediction results (temporary UUID dir if None)
39+
confidence: Threshold for filtering predictions (0-1)
40+
area_threshold: Minimum polygon area in sqm (default: 3)
41+
tolerance: Simplification tolerance in meters (default: 0.5)
42+
remove_metadata: Whether to delete intermediate files after processing
43+
orthogonalize: Whether to square building corners
44+
bbox: Bounding box for prediction area
45+
geojson: GeoJSON object for prediction area
46+
merge_input_images_to_single_image: Whether to merge source images
47+
get_predictions_as_points: Whether to generate point representations
48+
ortho_skew_tolerance_deg: Max skew angle for orthogonalization (0-45)
49+
ortho_max_angle_change_deg: Max corner adjustment angle (0-45)
4450
"""
4551
if not bbox and not geojson:
4652
raise ValueError("Either bbox or geojson must be provided")
4753
if confidence < 0 or confidence > 1:
4854
raise ValueError("Confidence must be between 0 and 1")
49-
if output_path:
50-
base_path = output_path
51-
else:
52-
base_path = os.path.join(os.getcwd(), "predictions", str(uuid.uuid4()))
5355

56+
base_path = output_path or os.path.join(
57+
os.getcwd(), "predictions", str(uuid.uuid4())
58+
)
5459
model_path = download_or_validate_model(model_path)
60+
5561
os.makedirs(base_path, exist_ok=True)
56-
meta_path = os.path.join(output_path, "meta")
57-
results_path = os.path.join(output_path, "results")
62+
meta_path, results_path = (
63+
os.path.join(output_path, "meta"),
64+
os.path.join(output_path, "results"),
65+
)
5866
os.makedirs(meta_path, exist_ok=True)
5967
os.makedirs(results_path, exist_ok=True)
6068

6169
image_download_path = os.path.join(meta_path, "image")
6270
os.makedirs(image_download_path, exist_ok=True)
63-
6471
image_download_path = await TMSDownloader.download_tiles(
6572
bbox=bbox,
6673
geojson=geojson,
@@ -69,69 +76,71 @@ async def predict(
6976
out=image_download_path,
7077
georeference=True,
7178
crs="3857",
72-
# dump=True,
7379
)
80+
7481
if merge_input_images_to_single_image:
7582
merge_rasters(
7683
image_download_path, os.path.join(meta_path, "merged_image_chips.tif")
7784
)
7885

7986
prediction_path = os.path.join(meta_path, "prediction")
8087
os.makedirs(prediction_path, exist_ok=True)
81-
8288
prediction_path = run_prediction(
8389
model_path,
8490
image_download_path,
8591
prediction_path=prediction_path,
8692
confidence=confidence,
8793
crs="3857",
8894
)
89-
start = time.time()
9095

96+
start = time.time()
9197
geojson_path = os.path.join(results_path, "geojson")
9298
os.makedirs(geojson_path, exist_ok=True)
93-
prediction_geojson_path = os.path.join(geojson_path, "predictions.geojson")
94-
9599
prediction_merged_mask_path = os.path.join(meta_path, "merged_prediction_mask.tif")
96100
os.makedirs(os.path.dirname(prediction_merged_mask_path), exist_ok=True)
97101

98-
# Merge rasters
99102
merge_rasters(prediction_path, prediction_merged_mask_path)
100103
tmp_dir = os.path.join(base_path, "tmp")
101-
converter = VectorizeMasks(
104+
gdf = VectorizeMasks(
102105
simplify_tolerance=tolerance,
103106
min_area=area_threshold,
104107
orthogonalize=orthogonalize,
105108
tmp_dir=tmp_dir,
106109
ortho_skew_tolerance_deg=ortho_skew_tolerance_deg,
107110
ortho_max_angle_change_deg=ortho_max_angle_change_deg,
111+
).convert(
112+
prediction_merged_mask_path, os.path.join(geojson_path, "predictions.geojson")
108113
)
109-
gdf = converter.convert(prediction_merged_mask_path, prediction_geojson_path)
114+
110115
shutil.rmtree(tmp_dir)
111116
print(f"It took {round(time.time() - start)} sec to extract polygons")
112117

113118
if gdf.crs and gdf.crs != "EPSG:4326":
114119
gdf = gdf.to_crs("EPSG:4326")
115120
elif not gdf.crs:
116-
# if not defined assume its 3857 because above 3857 is hardcoded
117121
gdf.set_crs("EPSG:3857", inplace=True)
118122
gdf = gdf.to_crs("EPSG:4326")
119123

120-
gdf["building"] = "yes"
121-
gdf["source"] = "fAIr"
124+
gdf["building"], gdf["source"] = "yes", "fAIr"
125+
126+
if remove_metadata:
127+
shutil.rmtree(meta_path)
122128

123129
if get_predictions_as_points:
124-
gdf_representative_points = gdf.copy()
125-
gdf_representative_points.geometry = gdf_representative_points.geometry.apply(
130+
gdf_points = gdf.copy()
131+
gdf_points.geometry = gdf_points.geometry.apply(
126132
lambda geom: geom.representative_point()
127133
)
128-
gdf_representative_points.to_file(
129-
os.path.join(geojson_path, "prediction_points.geojson"), driver="GeoJSON"
134+
gdf_points.to_file(
135+
os.path.join(geojson_path, "predictions_points.geojson"), driver="GeoJSON"
130136
)
137+
if not output_path:
138+
shutil.rmtree(base_path)
139+
return json.loads(gdf_points.to_json())
140+
131141
prediction_geojson_data = json.loads(gdf.to_json())
132142

133-
if remove_metadata:
134-
shutil.rmtree(meta_path)
135143
if not output_path:
136144
shutil.rmtree(base_path)
145+
137146
return prediction_geojson_data

0 commit comments

Comments
 (0)