11
11
from xgboost .sklearn import XGBRegressor
12
12
13
13
from . import constants
14
+ import quartz_solar_forecast
14
15
15
16
logger = logging .getLogger (__name__ )
16
17
@@ -36,6 +37,7 @@ class TryolabsSolarPowerPredictor:
36
37
Predicts solar power output for the given parameters.
37
38
"""
38
39
DATE_COLUMN = "date"
40
+ download_dir = os .path .dirname (quartz_solar_forecast .__file__ ) + "/models"
39
41
40
42
def _download_model (self , filename : str , repo_id : str , file_path : str ) -> str :
41
43
"""
@@ -56,12 +58,11 @@ def _download_model(self, filename: str, repo_id: str, file_path: str) -> str:
56
58
The path to the locally saved model file.
57
59
"""
58
60
# Use the project directory instead of the user's home directory
59
- download_dir = "/home/runner/work/Open-Source-Quartz-Solar-Forecast/Open-Source-Quartz-Solar-Forecast"
60
- os .makedirs (download_dir , exist_ok = True )
61
+ os .makedirs (self .download_dir , exist_ok = True )
61
62
62
- downloaded_file = hf_hub_download (repo_id = repo_id , filename = file_path , cache_dir = download_dir )
63
+ downloaded_file = hf_hub_download (repo_id = repo_id , filename = file_path , cache_dir = self . download_dir )
63
64
64
- target_path = os .path .join (download_dir , filename )
65
+ target_path = os .path .join (self . download_dir , filename )
65
66
66
67
# copy file from downloaded_file to target_path
67
68
shutil .copyfile (downloaded_file , target_path )
@@ -111,14 +112,13 @@ def load_model(
111
112
The loaded XGBoost model ready for making predictions.
112
113
"""
113
114
# Use the project directory
114
- download_dir = "/home/runner/work/Open-Source-Quartz-Solar-Forecast/Open-Source-Quartz-Solar-Forecast"
115
- zipfile_model = os .path .join (download_dir , model_file + ".zip" )
115
+ zipfile_model = os .path .join (self .download_dir , model_file + ".zip" )
116
116
117
117
if not os .path .isfile (zipfile_model ):
118
118
logger .info ("Downloading model..." )
119
119
zipfile_model = self ._download_model (model_file + ".zip" , repo_id , file_path )
120
120
121
- model_path = os .path .join (download_dir , model_file )
121
+ model_path = os .path .join (self . download_dir , model_file )
122
122
if not os .path .isfile (model_path ):
123
123
logger .info ("Preparing model..." )
124
124
self ._decompress_zipfile (zipfile_model )
0 commit comments