Skip to content

Commit be90200

Browse files
save the model in a sensible place (#227)
1 parent 810750b commit be90200

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

quartz_solar_forecast/forecasts/v2.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from xgboost.sklearn import XGBRegressor
1212

1313
from . import constants
14+
import quartz_solar_forecast
1415

1516
logger = logging.getLogger(__name__)
1617

@@ -36,6 +37,7 @@ class TryolabsSolarPowerPredictor:
3637
Predicts solar power output for the given parameters.
3738
"""
3839
DATE_COLUMN = "date"
40+
download_dir = os.path.dirname(quartz_solar_forecast.__file__) + "/models"
3941

4042
def _download_model(self, filename: str, repo_id: str, file_path: str) -> str:
4143
"""
@@ -56,12 +58,11 @@ def _download_model(self, filename: str, repo_id: str, file_path: str) -> str:
5658
The path to the locally saved model file.
5759
"""
5860
# Use the project directory instead of the user's home directory
59-
download_dir = "/home/runner/work/Open-Source-Quartz-Solar-Forecast/Open-Source-Quartz-Solar-Forecast"
60-
os.makedirs(download_dir, exist_ok=True)
61+
os.makedirs(self.download_dir, exist_ok=True)
6162

62-
downloaded_file = hf_hub_download(repo_id=repo_id, filename=file_path, cache_dir=download_dir)
63+
downloaded_file = hf_hub_download(repo_id=repo_id, filename=file_path, cache_dir=self.download_dir)
6364

64-
target_path = os.path.join(download_dir, filename)
65+
target_path = os.path.join(self.download_dir, filename)
6566

6667
# copy file from downloaded_file to target_path
6768
shutil.copyfile(downloaded_file, target_path)
@@ -111,14 +112,13 @@ def load_model(
111112
The loaded XGBoost model ready for making predictions.
112113
"""
113114
# Use the project directory
114-
download_dir = "/home/runner/work/Open-Source-Quartz-Solar-Forecast/Open-Source-Quartz-Solar-Forecast"
115-
zipfile_model = os.path.join(download_dir, model_file + ".zip")
115+
zipfile_model = os.path.join(self.download_dir, model_file + ".zip")
116116

117117
if not os.path.isfile(zipfile_model):
118118
logger.info("Downloading model...")
119119
zipfile_model = self._download_model(model_file + ".zip", repo_id, file_path)
120120

121-
model_path = os.path.join(download_dir, model_file)
121+
model_path = os.path.join(self.download_dir, model_file)
122122
if not os.path.isfile(model_path):
123123
logger.info("Preparing model...")
124124
self._decompress_zipfile(zipfile_model)

tests/test_forecast_no_ts.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def test_run_forecast_no_ts():
1111
current_hr = pd.Timestamp.now().round(freq='h')
1212

1313
# run gradient boosting model with no ts
14-
predications_df = run_forecast(site=site, model="gb")
14+
predications_df = run_forecast(site=site, model="gb", ts=current_ts)
1515
# check current ts agrees with dataset
1616
assert predications_df.index.min() == current_ts
1717

@@ -20,7 +20,7 @@ def test_run_forecast_no_ts():
2020
print(f"Max: {predications_df['power_kw'].max()}")
2121

2222
# run xgb model with no ts
23-
predications_df = run_forecast(site=site, model="xgb")
23+
predications_df = run_forecast(site=site, model="xgb", ts=current_ts)
2424
# check current ts agrees with dataset
2525
assert predications_df.index.min() == current_hr
2626

0 commit comments

Comments
 (0)