Skip to content

Commit fb96220

Browse files
committed
restore GPT5 code
1 parent c66c8a6 commit fb96220

File tree

1 file changed

+66
-30
lines changed

1 file changed

+66
-30
lines changed

pymc_extras/distributions/discrete.py

Lines changed: 66 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,12 @@ class GrassiaIIGeometricRV(RandomVariable):
412412
@classmethod
413413
def rng_fn(cls, rng, r, alpha, time_covariate_vector, size):
414414
# Aggregate time covariates for each sample before broadcasting
415-
exp_time_covar = np.exp(time_covariate_vector).sum(axis=0)
415+
time_cov = np.asarray(time_covariate_vector)
416+
if np.ndim(time_cov) == 0:
417+
exp_time_covar = np.asarray(1.0)
418+
else:
419+
# Collapse all time/feature axes to a scalar multiplier for RNG
420+
exp_time_covar = np.asarray(np.exp(time_cov).sum())
416421

417422
# Determine output size
418423
if size is None:
@@ -428,6 +433,11 @@ def rng_fn(cls, rng, r, alpha, time_covariate_vector, size):
428433
lam_covar = lam * exp_time_covar
429434

430435
p = 1 - np.exp(-lam_covar)
436+
# TODO: This is a hack to ensure valid probability in (0, 1]
437+
# We should find a better way to do this.
438+
# Ensure valid probability in (0, 1]
439+
tiny = np.finfo(p.dtype).tiny
440+
p = np.clip(p, tiny, 1.0)
431441
samples = rng.geometric(p)
432442
# samples = np.ceil(np.log(1 - rng.uniform(size=size)) / (-lam_covar))
433443

@@ -500,24 +510,29 @@ def dist(cls, r, alpha, time_covariate_vector=None, *args, **kwargs):
500510

501511
if time_covariate_vector is None:
502512
time_covariate_vector = pt.constant(0.0)
513+
time_covariate_vector = pt.as_tensor_variable(time_covariate_vector)
514+
# Normalize covariate to be 1D over time
515+
if time_covariate_vector.ndim == 0:
516+
time_covariate_vector = pt.reshape(time_covariate_vector, (1,))
517+
elif time_covariate_vector.ndim > 1:
518+
feature_axes = tuple(range(time_covariate_vector.ndim - 1))
519+
time_covariate_vector = pt.sum(time_covariate_vector, axis=feature_axes)
503520

504521
return super().dist([r, alpha, time_covariate_vector], *args, **kwargs)
505522

506523
def logp(value, r, alpha, time_covariate_vector):
507-
logp = pt.log(
508-
pt.pow(alpha / (alpha + C_t(value - 1, time_covariate_vector)), r)
509-
- pt.pow(alpha / (alpha + C_t(value, time_covariate_vector)), r)
510-
)
511-
512-
# Handle invalid values
513-
logp = pt.switch(
514-
pt.or_(
515-
value < 1, # Value must be >= 1
516-
pt.isnan(logp), # Handle NaN cases
517-
),
518-
-np.inf,
519-
logp,
520-
)
524+
v = pt.as_tensor_variable(value)
525+
ct_prev = C_t(v - 1, time_covariate_vector)
526+
ct_curr = C_t(v, time_covariate_vector)
527+
logS_prev = r * (pt.log(alpha) - pt.log(alpha + ct_prev))
528+
logS_curr = r * (pt.log(alpha) - pt.log(alpha + ct_curr))
529+
# Compute log(exp(logS_prev) - exp(logS_curr)) stably
530+
max_logS = pt.maximum(logS_prev, logS_curr)
531+
diff = pt.exp(logS_prev - max_logS) - pt.exp(logS_curr - max_logS)
532+
logp = max_logS + pt.log(diff)
533+
534+
# Handle invalid / out-of-domain values
535+
logp = pt.switch(value < 1, -np.inf, logp)
521536

522537
return check_parameters(
523538
logp,
@@ -527,9 +542,15 @@ def logp(value, r, alpha, time_covariate_vector):
527542
)
528543

529544
def logcdf(value, r, alpha, time_covariate_vector):
530-
logcdf = r * (
531-
pt.log(C_t(value, time_covariate_vector))
532-
- pt.log(alpha + C_t(value, time_covariate_vector))
545+
# Log CDF: log(1 - (alpha / (alpha + C(t)))**r)
546+
t = pt.as_tensor_variable(value)
547+
ct = C_t(t, time_covariate_vector)
548+
logS = r * (pt.log(alpha) - pt.log(alpha + ct))
549+
# Numerically stable log(1 - exp(logS))
550+
logcdf = pt.switch(
551+
pt.lt(logS, np.log(0.5)),
552+
pt.log1p(-pt.exp(logS)),
553+
pt.log(-pt.expm1(logS)),
533554
)
534555

535556
return check_parameters(
@@ -550,7 +571,6 @@ def support_point(rv, size, r, alpha, time_covariate_vector):
550571
When time_covariate_vector is provided, it affects the expected value through
551572
the exponential link function: exp(time_covariate_vector).
552573
"""
553-
554574
base_lambda = r / alpha
555575

556576
# Approximate expected value of geometric distribution
@@ -560,8 +580,11 @@ def support_point(rv, size, r, alpha, time_covariate_vector):
560580
1.0 / (1.0 - pt.exp(-base_lambda)), # Full expression for larger lambda
561581
)
562582

563-
# Apply time covariates if provided
564-
mean = mean * pt.exp(time_covariate_vector.sum(axis=0))
583+
# Apply time covariates if provided: multiply by exp(sum over axis=0)
584+
# This yields a scalar for 1D covariates and a time-length vector for 2D (features x time)
585+
tcv = pt.as_tensor_variable(time_covariate_vector)
586+
if tcv.ndim != 0:
587+
mean = mean * pt.exp(tcv.sum(axis=0))
565588

566589
# Round up to nearest integer and ensure >= 1
567590
mean = pt.maximum(pt.ceil(mean), 1.0)
@@ -575,14 +598,27 @@ def support_point(rv, size, r, alpha, time_covariate_vector):
575598

576599
def C_t(t: pt.TensorVariable, time_covariate_vector: pt.TensorVariable) -> pt.TensorVariable:
577600
"""Utility for processing time-varying covariates in GrassiaIIGeometric distribution."""
601+
# If unspecified (scalar), simply return t
578602
if time_covariate_vector.ndim == 0:
579-
# Reshape time_covariate_vector to length t
580-
return pt.full((t,), time_covariate_vector)
603+
return t
604+
605+
# Sum exp(covariates) across feature axes, keep last axis as time
606+
if time_covariate_vector.ndim == 1:
607+
per_time_sum = pt.exp(time_covariate_vector)
581608
else:
582-
# Ensure t is a valid index
583-
t_idx = pt.maximum(0, t - 1) # Convert to 0-based indexing
584-
# If t_idx exceeds length of time_covariate_vector, use last value
585-
max_idx = pt.shape(time_covariate_vector)[0] - 1
586-
safe_idx = pt.minimum(t_idx, max_idx)
587-
covariate_value = time_covariate_vector[..., safe_idx]
588-
return pt.exp(covariate_value).sum()
609+
# If axis=0 is time and axis>0 are features, sum over features (axis>0)
610+
per_time_sum = pt.sum(pt.exp(time_covariate_vector), axis=0)
611+
612+
# Build cumulative sum up to each t without advanced indexing
613+
time_length = pt.shape(per_time_sum)[0]
614+
# Ensure t is at least 1D int64 for broadcasting
615+
t_vec = pt.cast(t, "int64")
616+
t_vec = pt.shape_padleft(t_vec) if t_vec.ndim == 0 else t_vec
617+
# Create time indices [0, 1, ..., T-1]
618+
time_idx = pt.arange(time_length, dtype="int64")
619+
# Mask where time index < t (exclusive upper bound)
620+
mask = pt.lt(time_idx, pt.shape_padright(t_vec, 1))
621+
# Sum per-time contributions over time axis
622+
base_sum = pt.sum(pt.shape_padleft(per_time_sum) * mask, axis=-1)
623+
# If original t was scalar, return scalar (saturate at last time step)
624+
return pt.squeeze(base_sum)

0 commit comments

Comments
 (0)