Skip to content

Commit 444899f

Browse files
update
1 parent 1ef966e commit 444899f

File tree

3 files changed

+88
-85
lines changed

3 files changed

+88
-85
lines changed

src/main.py

Lines changed: 44 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@
99
from src.markov.transitions import build_transition_matrices
1010
from src.preprocessing.loader import load_timeseries
1111

12-
SIM_DAYS = 3
12+
SIM_DAYS = 10
13+
PER_DAY = 96
1314

1415

1516
def _detect_value_col(df: pd.DataFrame) -> str:
16-
candidate_cols = ["x", "value", "load", "power", "p_norm", "load_norm"]
17-
for c in candidate_cols:
17+
for c in ["x", "value", "load", "power", "p_norm", "load_norm"]:
1818
if c in df.columns and np.issubdtype(df[c].dtype, np.number):
1919
return c
20-
raise KeyError("No numeric load column found – please inspect the dataframe.")
20+
raise KeyError("numeric load column missing")
2121

2222

2323
def _simulate_series(
@@ -28,54 +28,25 @@ def _simulate_series(
2828
periods: int,
2929
rng: np.random.Generator | None = None,
3030
) -> pd.DataFrame:
31-
"""Generate synthetic 15‑min series (timestamp, state, x)."""
3231
rng = np.random.default_rng() if rng is None else rng
33-
timestamps = pd.date_range(start_ts, periods=periods, freq="15min")
32+
ts = pd.date_range(start_ts, periods=periods, freq="15min")
3433
states = np.empty(periods, dtype=int)
3534
xs = np.empty(periods, dtype=float)
3635

3736
s = start_state
38-
for i, ts in enumerate(timestamps):
39-
b = bucket_id(ts)
37+
for i, t in enumerate(ts):
38+
b = bucket_id(t)
4039
s = rng.choice(probs.shape[1], p=probs[b, s])
4140
states[i] = s
4241
xs[i] = sample_value(gmms, b, s, rng=rng)
43-
return pd.DataFrame({"timestamp": timestamps, "state": states, "x_sim": xs})
4442

45-
46-
def main() -> None:
47-
df = load_timeseries(normalize=True, discretize=True)
48-
if "bucket" not in df.columns:
49-
df = assign_buckets(df)
50-
51-
value_col = _detect_value_col(df)
52-
print("Using load column:", value_col)
53-
54-
counts = build_transition_counts(df)
55-
probs = build_transition_matrices(df)
56-
57-
_plot_first_25_buckets(counts, probs)
58-
59-
print("Fitting GMMs … (this may take a moment)")
60-
gmms = fit_gmms(df, value_col=value_col)
61-
62-
periods = SIM_DAYS * 96
63-
sim_df = _simulate_series(
64-
probs,
65-
gmms,
66-
start_ts=df["timestamp"].min().normalize(),
67-
start_state=int(df["state"].iloc[0]),
68-
periods=periods,
69-
)
70-
71-
_plot_simulation_diagnostics(df, sim_df, value_col)
43+
return pd.DataFrame({"timestamp": ts, "state": states, "x_sim": xs})
7244

7345

7446
def _plot_first_25_buckets(counts: np.ndarray, probs: np.ndarray) -> None:
75-
"""Heat‑map grid for buckets 0‑24."""
76-
buckets = list(range(25))
47+
buckets = range(25)
7748
fig, axes = plt.subplots(5, 5, figsize=(15, 15), sharex=True, sharey=True)
78-
vmax = probs[buckets].max()
49+
vmax = probs[list(buckets)].max()
7950
norm = Normalize(vmin=0, vmax=vmax)
8051

8152
for idx, b in enumerate(buckets):
@@ -92,15 +63,15 @@ def _plot_first_25_buckets(counts: np.ndarray, probs: np.ndarray) -> None:
9263
ax.axis("off")
9364

9465
fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.6, label="p")
95-
fig.suptitle("Transition probabilities – buckets 0–24", fontsize=14)
96-
fig.tight_layout(rect=[0, 0, 0.97, 0.96])
66+
fig.suptitle("Transition probabilities – buckets 0–24", fontsize=14)
67+
plt.tight_layout(rect=[0, 0, 0.97, 0.96])
9768
plt.show()
9869

9970

10071
def _plot_simulation_diagnostics(
10172
df: pd.DataFrame, sim: pd.DataFrame, value_col: str
10273
) -> None:
103-
first_day = sim.iloc[:96]
74+
first_day = sim.iloc[:PER_DAY]
10475
plt.figure(figsize=(10, 3))
10576
plt.plot(first_day["timestamp"], first_day["x_sim"], marker=".")
10677
plt.title("Simulated power – first day")
@@ -121,13 +92,43 @@ def _plot_simulation_diagnostics(
12192
sim["hour"] = sim["timestamp"].dt.hour
12293
plt.figure(figsize=(10, 4))
12394
sim.boxplot(column="x_sim", by="hour", grid=False)
124-
plt.suptitle("")
12595
plt.title("Simulated power by hour of day")
12696
plt.xlabel("hour of day")
12797
plt.ylabel("normalised load x")
12898
plt.tight_layout()
12999
plt.show()
130100

131101

102+
def main() -> None:
103+
df = load_timeseries(normalize=True, discretize=True)
104+
if "bucket" not in df.columns:
105+
df = assign_buckets(df)
106+
107+
val_col = _detect_value_col(df)
108+
109+
counts = build_transition_counts(df)
110+
probs = build_transition_matrices(df)
111+
112+
_plot_first_25_buckets(counts, probs)
113+
114+
gmms = fit_gmms(
115+
df,
116+
value_col=val_col,
117+
verbose=1,
118+
heartbeat_seconds=60,
119+
)
120+
121+
periods = SIM_DAYS * PER_DAY
122+
sim = _simulate_series(
123+
probs,
124+
gmms,
125+
start_ts=df["timestamp"].min().normalize(),
126+
start_state=int(df["state"].iloc[0]),
127+
periods=periods,
128+
)
129+
130+
_plot_simulation_diagnostics(df, sim, val_col)
131+
132+
132133
if __name__ == "__main__":
133134
main()

src/markov/gmm.py

Lines changed: 44 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,25 @@
11
import math
2-
from typing import Dict, List, Tuple
2+
from typing import List, Optional, Tuple
33

44
import joblib
55
import numpy as np
66
import pandas as pd
77
from sklearn.mixture import GaussianMixture
88

99
try:
10-
# rich progress bar for terminals / notebooks; optional dependency
11-
from tqdm.auto import tqdm # type: ignore
10+
from tqdm.auto import tqdm
1211
except ImportError:
13-
tqdm = None # type: ignore
12+
tqdm = None
1413

1514
from .buckets import NUM_BUCKETS
1615
from .transition_counts import N_STATES
1716

18-
__all__ = [
19-
"GaussianBucketModels",
20-
"fit_gmms",
21-
"sample_value",
22-
]
23-
17+
__all__ = ["GaussianBucketModels", "fit_gmms", "sample_value"]
2418

2519
GmmTuple = Tuple[np.ndarray, np.ndarray, np.ndarray]
26-
GaussianBucketModels = List[List[GmmTuple]]
20+
GaussianBucketModels = List[List[Optional[GmmTuple]]]
21+
22+
_rng = np.random.default_rng()
2723

2824

2925
def _fit_single(
@@ -32,31 +28,35 @@ def _fit_single(
3228
min_samples: int = 30,
3329
k_candidates: Tuple[int, ...] = (1, 2, 3),
3430
random_state: int | None = None,
35-
) -> GmmTuple:
36-
"""Fit 1‑3 component spherical GMM; fallback to Normal if too few samples."""
31+
) -> Optional[GmmTuple]:
32+
if x.size == 0:
33+
return None
3734
if len(x) < min_samples:
3835
mean = float(np.mean(x))
39-
var = float(np.var(x) + 1e-6)
40-
return (np.array([1.0]), np.array([mean]), np.array([var]))
41-
36+
var = float(max(np.var(x), 1e-5))
37+
return (
38+
np.array([1.0], dtype=float),
39+
np.array([mean], dtype=float),
40+
np.array([var], dtype=float),
41+
)
4242
best_bic = math.inf
4343
best_gmm: GaussianMixture | None = None
4444
for k in k_candidates:
4545
gmm = GaussianMixture(
4646
n_components=k,
4747
covariance_type="spherical",
48-
n_init="auto",
48+
n_init=1,
4949
random_state=random_state,
5050
).fit(x.reshape(-1, 1))
5151
bic = gmm.bic(x.reshape(-1, 1))
5252
if bic < best_bic:
5353
best_bic = bic
5454
best_gmm = gmm
55-
56-
weights = best_gmm.weights_
57-
means = best_gmm.means_.ravel()
58-
variances = best_gmm.covariances_.ravel()
59-
return (weights, means, variances)
55+
return (
56+
best_gmm.weights_,
57+
best_gmm.means_.ravel(),
58+
best_gmm.covariances_.ravel(),
59+
)
6060

6161

6262
def fit_gmms(
@@ -65,29 +65,37 @@ def fit_gmms(
6565
value_col: str = "x",
6666
bucket_col: str = "bucket",
6767
state_col: str = "state",
68-
min_samples: int = 30,
68+
min_samples: int = 5,
6969
k_candidates: Tuple[int, ...] = (1, 2, 3),
7070
n_jobs: int = -1,
7171
random_state: int | None = None,
7272
verbose: int = 0,
73+
heartbeat_seconds: int | None = None,
7374
) -> GaussianBucketModels:
74-
"""Return list [bucket][state] -> (weights, means, variances)."""
75+
if heartbeat_seconds:
76+
import faulthandler
77+
78+
faulthandler.enable()
79+
faulthandler.dump_traceback_later(heartbeat_seconds, repeat=True)
7580

76-
samples: Dict[Tuple[int, int], List[float]] = {}
77-
for _, row in df[[bucket_col, state_col, value_col]].iterrows():
78-
samples.setdefault((row[bucket_col], row[state_col]), []).append(row[value_col])
81+
grouped = (
82+
df[[bucket_col, state_col, value_col]]
83+
.groupby([bucket_col, state_col])[value_col]
84+
.apply(list)
85+
.to_dict()
86+
)
7987

8088
tasks = [
81-
((b, s), samples.get((b, s), []))
89+
((b, s), grouped.get((b, s), []))
8290
for b in range(NUM_BUCKETS)
8391
for s in range(N_STATES)
8492
]
8593

86-
iterable = tasks
87-
if verbose > 0 and tqdm is not None:
88-
iterable = tqdm(tasks, desc="Fitting GMMs", unit="model")
94+
iterable = (
95+
tqdm(tasks, desc="Fitting GMMs", unit="model") if verbose and tqdm else tasks
96+
)
8997

90-
def _worker(item: Tuple[Tuple[int, int], List[float]]):
98+
def _worker(item):
9199
(b, s), x_list = item
92100
x = np.asarray(x_list, dtype=float)
93101
return (
@@ -113,17 +121,16 @@ def _worker(item: Tuple[Tuple[int, int], List[float]]):
113121
return gmms
114122

115123

116-
_rng = np.random.default_rng()
117-
118-
119124
def sample_value(
120125
gmms: GaussianBucketModels,
121126
bucket: int,
122127
state: int,
123128
rng: np.random.Generator | None = None,
124129
) -> float:
125-
"""Draw a normalised load value from the GMM for (bucket, state)."""
126-
weights, means, vars_ = gmms[bucket][state]
130+
gmm = gmms[bucket][state]
131+
if gmm is None:
132+
raise ValueError(f"No GMM trained for bucket {bucket}, state {state}")
133+
weights, means, vars_ = gmm
127134
rng = _rng if rng is None else rng
128135
comp = rng.choice(len(weights), p=weights)
129136
return float(rng.normal(means[comp], math.sqrt(vars_[comp])))

src/markov/transition_counts.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,6 @@ def build_transition_counts(
1515
bucket_col: str = "bucket",
1616
dtype=np.uint32,
1717
) -> np.ndarray:
18-
"""
19-
Absolute transition counts:
20-
C[b, i, j] = # of times state_t=i → state_{t+1}=j in bucket b
21-
Shape = (2 304, 10, 10).
22-
"""
2318
df = df.sort_values("timestamp")
2419

2520
s_t = df[state_col].to_numpy(dtype=int)[:-1]

0 commit comments

Comments
 (0)