@@ -412,7 +412,12 @@ class GrassiaIIGeometricRV(RandomVariable):
412
412
@classmethod
413
413
def rng_fn (cls , rng , r , alpha , time_covariate_vector , size ):
414
414
# Aggregate time covariates for each sample before broadcasting
415
- exp_time_covar = np .exp (time_covariate_vector ).sum (axis = 0 )
415
+ time_cov = np .asarray (time_covariate_vector )
416
+ if np .ndim (time_cov ) == 0 :
417
+ exp_time_covar = np .asarray (1.0 )
418
+ else :
419
+ # Collapse all time/feature axes to a scalar multiplier for RNG
420
+ exp_time_covar = np .asarray (np .exp (time_cov ).sum ())
416
421
417
422
# Determine output size
418
423
if size is None :
@@ -428,6 +433,11 @@ def rng_fn(cls, rng, r, alpha, time_covariate_vector, size):
428
433
lam_covar = lam * exp_time_covar
429
434
430
435
p = 1 - np .exp (- lam_covar )
436
+ # TODO: This is a hack to ensure valid probability in (0, 1]
437
+ # We should find a better way to do this.
438
+ # Ensure valid probability in (0, 1]
439
+ tiny = np .finfo (p .dtype ).tiny
440
+ p = np .clip (p , tiny , 1.0 )
431
441
samples = rng .geometric (p )
432
442
# samples = np.ceil(np.log(1 - rng.uniform(size=size)) / (-lam_covar))
433
443
@@ -500,24 +510,29 @@ def dist(cls, r, alpha, time_covariate_vector=None, *args, **kwargs):
500
510
501
511
if time_covariate_vector is None :
502
512
time_covariate_vector = pt .constant (0.0 )
513
+ time_covariate_vector = pt .as_tensor_variable (time_covariate_vector )
514
+ # Normalize covariate to be 1D over time
515
+ if time_covariate_vector .ndim == 0 :
516
+ time_covariate_vector = pt .reshape (time_covariate_vector , (1 ,))
517
+ elif time_covariate_vector .ndim > 1 :
518
+ feature_axes = tuple (range (time_covariate_vector .ndim - 1 ))
519
+ time_covariate_vector = pt .sum (time_covariate_vector , axis = feature_axes )
503
520
504
521
return super ().dist ([r , alpha , time_covariate_vector ], * args , ** kwargs )
505
522
506
523
def logp (value , r , alpha , time_covariate_vector ):
507
- logp = pt .log (
508
- pt .pow (alpha / (alpha + C_t (value - 1 , time_covariate_vector )), r )
509
- - pt .pow (alpha / (alpha + C_t (value , time_covariate_vector )), r )
510
- )
511
-
512
- # Handle invalid values
513
- logp = pt .switch (
514
- pt .or_ (
515
- value < 1 , # Value must be >= 1
516
- pt .isnan (logp ), # Handle NaN cases
517
- ),
518
- - np .inf ,
519
- logp ,
520
- )
524
+ v = pt .as_tensor_variable (value )
525
+ ct_prev = C_t (v - 1 , time_covariate_vector )
526
+ ct_curr = C_t (v , time_covariate_vector )
527
+ logS_prev = r * (pt .log (alpha ) - pt .log (alpha + ct_prev ))
528
+ logS_curr = r * (pt .log (alpha ) - pt .log (alpha + ct_curr ))
529
+ # Compute log(exp(logS_prev) - exp(logS_curr)) stably
530
+ max_logS = pt .maximum (logS_prev , logS_curr )
531
+ diff = pt .exp (logS_prev - max_logS ) - pt .exp (logS_curr - max_logS )
532
+ logp = max_logS + pt .log (diff )
533
+
534
+ # Handle invalid / out-of-domain values
535
+ logp = pt .switch (value < 1 , - np .inf , logp )
521
536
522
537
return check_parameters (
523
538
logp ,
@@ -527,9 +542,15 @@ def logp(value, r, alpha, time_covariate_vector):
527
542
)
528
543
529
544
def logcdf (value , r , alpha , time_covariate_vector ):
530
- logcdf = r * (
531
- pt .log (C_t (value , time_covariate_vector ))
532
- - pt .log (alpha + C_t (value , time_covariate_vector ))
545
+ # Log CDF: log(1 - (alpha / (alpha + C(t)))**r)
546
+ t = pt .as_tensor_variable (value )
547
+ ct = C_t (t , time_covariate_vector )
548
+ logS = r * (pt .log (alpha ) - pt .log (alpha + ct ))
549
+ # Numerically stable log(1 - exp(logS))
550
+ logcdf = pt .switch (
551
+ pt .lt (logS , np .log (0.5 )),
552
+ pt .log1p (- pt .exp (logS )),
553
+ pt .log (- pt .expm1 (logS )),
533
554
)
534
555
535
556
return check_parameters (
@@ -550,7 +571,6 @@ def support_point(rv, size, r, alpha, time_covariate_vector):
550
571
When time_covariate_vector is provided, it affects the expected value through
551
572
the exponential link function: exp(time_covariate_vector).
552
573
"""
553
-
554
574
base_lambda = r / alpha
555
575
556
576
# Approximate expected value of geometric distribution
@@ -560,8 +580,11 @@ def support_point(rv, size, r, alpha, time_covariate_vector):
560
580
1.0 / (1.0 - pt .exp (- base_lambda )), # Full expression for larger lambda
561
581
)
562
582
563
- # Apply time covariates if provided
564
- mean = mean * pt .exp (time_covariate_vector .sum (axis = 0 ))
583
+ # Apply time covariates if provided: multiply by exp(sum over axis=0)
584
+ # This yields a scalar for 1D covariates and a time-length vector for 2D (features x time)
585
+ tcv = pt .as_tensor_variable (time_covariate_vector )
586
+ if tcv .ndim != 0 :
587
+ mean = mean * pt .exp (tcv .sum (axis = 0 ))
565
588
566
589
# Round up to nearest integer and ensure >= 1
567
590
mean = pt .maximum (pt .ceil (mean ), 1.0 )
@@ -575,14 +598,27 @@ def support_point(rv, size, r, alpha, time_covariate_vector):
575
598
576
599
def C_t(t: pt.TensorVariable, time_covariate_vector: pt.TensorVariable) -> pt.TensorVariable:
    """Utility for processing time-varying covariates in GrassiaIIGeometric distribution.

    Computes the cumulative covariate multiplier C(t): the sum of per-time-step
    exp(covariate) contributions over the first ``t`` time steps. When no
    covariates were supplied (scalar placeholder), degenerates to ``t`` itself,
    which corresponds to a constant unit multiplier per time step.

    Parameters
    ----------
    t : pt.TensorVariable
        Time index (scalar or 1D). Cast to int64 internally; the mask uses an
        exclusive upper bound, so ``t`` counts whole completed time steps.
        NOTE(review): assumes ``t >= 0`` — negative values yield an empty mask
        and a zero sum; confirm callers guarantee this.
    time_covariate_vector : pt.TensorVariable
        Either a scalar placeholder (no covariates), a 1D vector indexed by
        time, or a higher-rank tensor whose axis 0 is summed out — presumably
        a feature axis with time as the remaining axis; verify against callers.

    Returns
    -------
    pt.TensorVariable
        Cumulative sum of exp(covariate) over time steps 0..t-1, squeezed to a
        scalar when ``t`` was scalar.
    """
    # No covariates: a scalar placeholder was passed, so each time step
    # contributes 1 and the cumulative value is just t.
    if time_covariate_vector.ndim == 0:
        return t

    # Per-time-step contribution exp(covariate_t).
    if time_covariate_vector.ndim == 1:
        per_time_sum = pt.exp(time_covariate_vector)
    else:
        # Rank >= 2: collapse axis 0 (the non-time axis here) so the result is
        # indexed by time along its leading remaining axis.
        per_time_sum = pt.sum(pt.exp(time_covariate_vector), axis=0)

    # Build the cumulative sum up to each t with a broadcasted mask instead of
    # advanced indexing (keeps the graph free of take/gather ops).
    time_length = pt.shape(per_time_sum)[0]
    # Ensure t is integer-valued and at least 1D so it broadcasts against the
    # time axis below.
    t_vec = pt.cast(t, "int64")
    t_vec = pt.shape_padleft(t_vec) if t_vec.ndim == 0 else t_vec
    # Time indices [0, 1, ..., T-1].
    time_idx = pt.arange(time_length, dtype="int64")
    # mask[i, j] = (j < t_i): selects the first t complete time steps
    # (exclusive upper bound). For t > T every step is selected, so the sum
    # implicitly saturates at the full observed horizon.
    mask = pt.lt(time_idx, pt.shape_padright(t_vec, 1))
    # Sum the masked per-time contributions over the time axis; shapes are
    # (1, T) * (n, T) -> (n, T) -> (n,) by broadcasting.
    base_sum = pt.sum(pt.shape_padleft(per_time_sum) * mask, axis=-1)
    # If the original t was scalar, return a scalar.
    return pt.squeeze(base_sum)
0 commit comments