You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
title={Computation through neural population dynamics},
3
+
author={Vyas, Saurabh and Golub, Matthew D and Sussillo, David and Shenoy, Krishna V},
4
+
journal={Annual review of neuroscience},
5
+
volume={43},
6
+
number={1},
7
+
pages={249--275},
8
+
year={2020},
9
+
publisher={Annual Reviews},
10
+
doi={10.1146/annurev-neuro-092619-094115}
11
+
}
12
+
13
+
@book{murphy2023probabilistic,
14
+
author = "Kevin P. Murphy",
15
+
title = "Probabilistic Machine Learning: Advanced Topics",
16
+
publisher = "MIT Press",
17
+
year = 2023,
18
+
url = "http://probml.github.io/book2"
19
+
}
20
+
21
+
@book{sarkka2023bayesian,
22
+
title={Bayesian filtering and smoothing},
23
+
author={S{\"a}rkk{\"a}, Simo and Svensson, Lennart},
24
+
volume={17},
25
+
year={2023},
26
+
publisher={Cambridge University Press},
27
+
doi={10.1017/CBO9781139344203}
28
+
}
29
+
30
+
31
+
@misc{jax,
32
+
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
33
+
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
author={Lee, Hyun Dong and Warrington, Andrew and Glaser, Joshua and Linderman, Scott},
52
+
journal={Advances in Neural Information Processing Systems},
53
+
volume={36},
54
+
pages={57976--58010},
55
+
year={2023},
56
+
doi={10.48550/arXiv.2306.03291}
57
+
}
58
+
59
+
@inproceedings{chang2023low,
60
+
title = {Low-rank extended {K}alman filtering for online learning of neural networks from streaming data},
61
+
author = {Chang, Peter G. and Dur\'an-Mart\'in, Gerardo and Shestopaloff, Alex and Jones, Matt and Murphy, Kevin P},
62
+
booktitle = {Proceedings of The 2nd Conference on Lifelong Learning Agents},
63
+
pages = {1025--1071},
64
+
year = {2023},
65
+
editor = {Chandar, Sarath and Pascanu, Razvan and Sedghi, Hanie and Precup, Doina},
66
+
volume = {232},
67
+
series = {Proceedings of Machine Learning Research},
68
+
month = {22--25 Aug},
69
+
publisher = {PMLR},
70
+
doi={10.48550/arXiv.2305.19535},
71
+
}
72
+
73
+
74
+
@article{weinreb2024keypoint,
75
+
author = {Weinreb, Caleb and Pearl, Jonah E. and Lin, Sherry and Osman, Mohammed Abdal Monium and Zhang, Libby and Annapragada, Sidharth and Conlin, Eli and Hoffmann, Red and Makowska, Sofia and Gillis, Winthrop F. and Jay, Maya and Ye, Shaokai and Mathis, Alexander and Mathis, Mackenzie W. and Pereira, Talmo and Linderman, Scott W. and Datta, Sandeep Robert},
76
+
date = {2024/07/01},
77
+
id = {Weinreb2024},
78
+
journal = {Nature Methods},
79
+
number = {7},
80
+
pages = {1329--1339},
81
+
title = {Keypoint-{M}o{S}eq: parsing behavior by linking point tracking to pose dynamics},
82
+
volume = {21},
83
+
year = {2024},
84
+
doi={10.1038/s41592-024-02318-2},
85
+
}
86
+
87
+
@misc{pyhsmm,
88
+
author = {Matthew James Johnson},
89
+
title = {{PyHSMM}: Bayesian inference in HSMMs and HMMs},
author = {Duran-Martin, Gerardo and Murphy, Kevin and Kara, Aleyna},
114
+
title = {{JSL: JAX State-Space models (SSM) Library}},
115
+
url={https://github.com/probml/JSL},
116
+
year={2022}
117
+
}
118
+
119
+
@inproceedings{seabold2010statsmodels,
120
+
title={statsmodels: {E}conometric and statistical modeling with python},
121
+
author={Seabold, Skipper and Perktold, Josef},
122
+
booktitle={9th Python in Science Conference},
123
+
year={2010},
124
+
doi={10.25080/majora-92bf1922-011}
125
+
}
126
+
127
+
@misc{hmmlearn,
128
+
author={Ron Weiss and Shiqiao Du and Jaques Grobler and David Cournapeau and Fabian Pedregosa and Gael Varoquaux and Andreas Mueller and Bertrand Thirion and Daniel Nouri and Gilles Louppe and Jake Vanderplas and John Benediktsson and Lars Buitinck and Mikhail Korobov and Robert McGibbon and Stefano Lattarini and Vlad Niculae and Alexandre Gramfort and Sergei Lebedev and Daniela Huppenkothen and Christopher Farrow and Alexandr Yanenko and Antony Lee and Matthew Danielson and Alex Rockhill},
title = "Biological sequence analysis: {P}robabilistic models of proteins
137
+
and nucleic acids",
138
+
author = "Durbin, Richard and Eddy, Sean R and Krogh, Anders and Mitchison,
139
+
Graeme",
140
+
publisher = "Cambridge University Press",
141
+
month = apr,
142
+
year = 1998,
143
+
doi={10.1017/cbo9780511790492},
144
+
}
145
+
146
+
@article{patterson2008state,
147
+
title={State-space models of individual animal movement},
148
+
author={Patterson, Toby A and Thomas, Len and Wilcox, Chris and Ovaskainen, Otso and Matthiopoulos, Jason},
149
+
journal={Trends in ecology \& evolution},
150
+
volume={23},
151
+
number={2},
152
+
pages={87--94},
153
+
year={2008},
154
+
publisher={Elsevier},
155
+
doi={10.1016/j.tree.2007.10.009}
156
+
}
157
+
158
+
@article{jacquier2002bayesian,
159
+
title={Bayesian analysis of stochastic volatility models},
160
+
author={Jacquier, Eric and Polson, Nicholas G and Rossi, Peter E},
161
+
journal={Journal of Business \& Economic Statistics},
162
+
volume={20},
163
+
number={1},
164
+
pages={69--87},
165
+
year={2002},
166
+
publisher={Taylor \& Francis},
167
+
doi={10.1198/073500102753410408}
168
+
}
169
+
170
+
@article{ott2004local,
171
+
title={A local ensemble {K}alman filter for atmospheric data assimilation},
172
+
author={Ott, Edward and Hunt, Brian R and Szunyogh, Istvan and Zimin, Aleksey V and Kostelich, Eric J and Corazza, Matteo and Kalnay, Eugenia and Patil, DJ and Yorke, James A},
173
+
journal={Tellus A: Dynamic Meteorology and Oceanography},
174
+
volume={56},
175
+
number={5},
176
+
pages={415--428},
177
+
year={2004},
178
+
publisher={Taylor \& Francis},
179
+
doi={10.3402/tellusa.v56i5.14462}
180
+
}
181
+
182
+
@article{stone1975parallel,
183
+
title={Parallel tridiagonal equation solvers},
184
+
author={Stone, Harold S},
185
+
journal={ACM Transactions on Mathematical Software (TOMS)},
186
+
volume={1},
187
+
number={4},
188
+
pages={289--307},
189
+
year={1975},
190
+
publisher={ACM New York, NY, USA},
191
+
doi={10.1145/355656.355657}
192
+
}
193
+
194
+
@article{sarkka2020temporal,
195
+
title={Temporal parallelization of {B}ayesian smoothers},
196
+
author={S{\"a}rkk{\"a}, Simo and Garc{\'\i}a-Fern{\'a}ndez, {\'A}ngel F},
197
+
journal={IEEE Transactions on Automatic Control},
198
+
volume={66},
199
+
number={1},
200
+
pages={299--306},
201
+
year={2020},
202
+
publisher={IEEE},
203
+
doi={10.1109/TAC.2020.2976316}
204
+
}
205
+
206
+
@article{hassan2021temporal,
207
+
title={Temporal parallelization of inference in hidden {M}arkov models},
208
+
author={Hassan, Syeda Sakira and S{\"a}rkk{\"a}, Simo and Garc{\'\i}a-Fern{\'a}ndez, {\'A}ngel F},
title: 'Dynamax: A Python package for probabilistic state space modeling with JAX'
3
+
tags:
4
+
- Python
5
+
- State space models
6
+
- dynamics
7
+
- JAX
8
+
9
+
authors:
10
+
- name: Scott W. Linderman
11
+
orcid: 0000-0002-3878-9073
12
+
affiliation: "1"# (Multiple affiliations must be quoted)
13
+
corresponding: true
14
+
- name: Peter Chang
15
+
affiliation: "2"
16
+
- name: Giles Harper-Donnelly
17
+
affiliation: "3"
18
+
- name: Aleyna Kara
19
+
affiliation: "4"
20
+
- name: Xinglong Li
21
+
affiliation: "5"
22
+
- name: Gerardo Duran-Martin
23
+
affiliation: "6"
24
+
- name: Kevin Murphy
25
+
affiliation: "7"
26
+
corresponding: true
27
+
affiliations:
28
+
- name: Department of Statistics and Wu Tsai Neurosciences Institute, Stanford University, USA
29
+
index: 1
30
+
- name: CSAIL, Massachusetts Institute of Technology, USA
31
+
index: 2
32
+
- name: Cambridge University, England, UK
33
+
index: 3
34
+
- name: Computer Science Department, Technical University of Munich Garching, Germany
35
+
index: 4
36
+
- name: Statistics Department, University of British Columbia, Canada
37
+
index: 5
38
+
- name: Queen Mary University of London, England, UK
39
+
index: 6
40
+
- name: Google DeepMind, USA
41
+
index: 7
42
+
43
+
date: 19 July 2024
44
+
bibliography: paper.bib
45
+
46
+
---
47
+
48
+
# Summary
49
+
50
+
State space models (SSMs) are fundamental tools for modeling sequential data. They are broadly used across engineering disciplines like signal processing and control theory, as well as scientific domains like neuroscience [@vyas2020computation], genetics [@durbin1998biological], ecology [@patterson2008state], computational ethology [@weinreb2024keypoint], economics [@jacquier2002bayesian], and climate science [@ott2004local]. Fast and robust tools for state space modeling are crucial to researchers in all of these application areas.
51
+
52
+
State space models specify a probability distribution over a sequence of observations, $y_1, \ldots y_T$, where $y_t$ denotes the observation at time $t$. The key assumption of an SSM is that the observations arise from a sequence of _latent states_, $z_1, \ldots, z_T$, which evolve according to a _dynamics model_ (aka transition model). An SSM may also use inputs (aka controls or covariates), $u_1,\ldots,u_T$, to steer the latent state dynamics and influence the observations.
53
+
For example, in a neuroscience application from @vyas2020computation, $y_t$ represents a vector of spike counts from $\sim 1000$ measured neurons, and $z_t$ is a lower dimensional latent state that changes slowly over time and captures correlations among the measured neurons. If sensory inputs to the neural circuit are known, they can be encoded in $u_t$.
54
+
In the computational ethology application of @weinreb2024keypoint, $y_t$ represents a vector of 3D locations for several key points on an animal's body, and $z_t$ is a discrete behavioral state that specifies how the animal's posture changes over time.
55
+
In both examples, there are two main objectives: First, we aim to infer the latent states $z_t$ that best explain the observed data; formally, this is called _state inference_.
56
+
Second, we need to estimate the dynamics that govern how latent states evolve; formally, this is part of the _parameter estimation_ process.
57
+
`Dynamax` provides algorithms for state inference and parameter estimation in a variety of SSMs.
58
+
59
+
There are a few key design choices to make when constructing an SSM:
60
+
61
+
- What is the type of latent state? E.g., is $z_t$ a continuous or discrete random variable?
62
+
- How do the latent states evolve over time? E.g., are the dynamics linear or nonlinear?
63
+
- How are the observations distributed? E.g., are they Gaussian, Poisson, etc.?
64
+
65
+
Some design choices are so common they have their own names. Hidden Markov models (HMM) are SSMs with discrete latent states, and linear dynamical systems (LDS) are SSMs with continuous latent states, linear dynamics, and additive Gaussian noise. `Dynamax` supports canonical SSMs and allows the user to construct bespoke models as needed, simply by inheriting from a base class and specifying a few model-specific functions. For example, see the _Creating Custom HMMs_ tutorial in the Dynamax documentation.
66
+
67
+
Finally, even for canonical models, there are several algorithms for state inference and parameter estimation. `Dynamax` provides robust implementations of several low-level inference algorithms to suit a variety of applications, allowing users to choose among a host of models and algorithms for their application. More information about state space models and algorithms for state inference and parameter estimation can be found in the textbooks by @murphy2023probabilistic and @sarkka2023bayesian.
68
+
69
+
70
+
# Statement of need
71
+
72
+
`Dynamax` is an open-source Python package for state space modeling. Since it is built with `JAX` [@jax], it supports just-in-time (JIT) compilation for hardware acceleration on CPU, GPU, and TPU machines. It also supports automatic differentiation for gradient-based model learning. While other libraries exist for state space modeling in Python [@pyhsmm; @ssm; @eeasensors; @seabold2010statsmodels; @hmmlearn] and Julia [@dalle2024hiddenmarkovmodels], `Dynamax` provides a diverse combination of low-level inference algorithms and high-level modeling objects that can support a wide range of research applications in JAX. Additionally, `Dynamax` implements parallel message passing algorithms that leverage the associative scan (a.k.a., parallel scan) primitive in JAX to take full advantage of modern hardware accelerators. Currently, these primitives are not natively supported in other frameworks like PyTorch. While various subsets of these models and algorithms may be found in other libraries, Dynamax is a "one stop shop" for state space modeling in JAX.
73
+
74
+
The API for `Dynamax` is divided into two parts: a set of core, functionally pure, low-level inference algorithms, and a high-level, object oriented module for constructing and fitting probabilistic SSMs. The low-level inference API provides message passing algorithms for several common types of SSMs. For example, `Dynamax` provides `JAX` implementations for:
75
+
76
+
- Forward-Backward algorithms for discrete-state hidden Markov models (HMMs),
77
+
- Kalman filtering and smoothing algorithms for linear Gaussian SSMs,
78
+
- Extended and unscented generalized Kalman filtering and smoothing for nonlinear and/or non-Gaussian SSMs, and
79
+
- Parallel message passing routines that leverage GPU or TPU acceleration to perform message passing in $O(\log T)$ time on a parallel machine [@stone1975parallel; @sarkka2020temporal; @hassan2021temporal]. Note that these routines are not simply parallelizing over batches of time series, but rather using a parallel algorithm with sublinear depth or span.
80
+
81
+
The high-level model API makes it easy to construct, fit, and inspect HMMs and linear Gaussian SSMs. Finally, the online `Dynamax` documentation and tutorials provide a wealth of resources for state space modeling experts and newcomers alike.
82
+
83
+
`Dynamax` has supported several publications. The low-level API has been used in machine learning research [@zhao2023revisiting; @lee2023switching; @chang2023low]. Special purpose libraries have been built on top of `Dynamax`, like the Keypoint-MoSeq library for modeling animal behavior [@weinreb2024keypoint] and the Structural Time Series in JAX library, `sts-jax`[@sts-jax]. Finally, the `Dynamax` tutorials are used as reference examples in a major machine learning textbook [@murphy2023probabilistic].
84
+
85
+
# Acknowledgements
86
+
87
+
A significant portion of this library was developed while S.W.L. was a Visiting Faculty Researcher at Google and P.C., G.H.D., A.K., and X.L. were Google Summer of Code participants.
0 commit comments