Feat/docs #11

Draft · wants to merge 9 commits into main

3 changes: 3 additions & 0 deletions .gitignore
@@ -12,3 +12,6 @@ wheels/
jax_cache/

.DS_Store

site/
slides/
Empty file added docs/dev.md
Empty file.
16 changes: 16 additions & 0 deletions docs/index.md
@@ -0,0 +1,16 @@
## What is GenJAXMix?
GenJAXMix is a Dirichlet Process Mixture Model (DPMM) framework written in `GenJAX`, a JAX-accelerated probabilistic programming language. It provides a fast GPU implementation for modeling and inferring clusters with DPMMs using clean *Genial* programs.

## Quickstart
GenJAXMix is currently *private*, so please install it directly from GitHub:

```
pip install git+ssh://git@github.com/OpenGen/genjaxmix.git
```

## For Developers
GenJAXMix uses `uv` for package management and development. To set up GenJAXMix for development, please install `uv` and run
```
uv sync
```

1 change: 1 addition & 0 deletions docs/reference/api.md
@@ -0,0 +1 @@
::: src.genjaxmix.dpmm
1 change: 1 addition & 0 deletions docs/reference/conjugacy.md
@@ -0,0 +1 @@
::: src.genjaxmix.conjugacy
1 change: 1 addition & 0 deletions docs/reference/inference.md
@@ -0,0 +1 @@
::: src.genjaxmix.smc
266 changes: 266 additions & 0 deletions docs/tutorials/dpmm_ground_up.ipynb

Large diffs are not rendered by default.

58 changes: 58 additions & 0 deletions docs/tutorials/getting_started.py
@@ -0,0 +1,58 @@
# %% [markdown]
# # Getting Started

# %% [markdown]
# This tutorial introduces Dirichlet Process Mixture Models and explores how to cluster using a simple example.

# %% [markdown]
# # Dataset
# First we need a dataset. In this tutorial, we will synthetically generate points on on the 2D plane and form small clusters that we expect to later detect during inference. We will generate data using a [Gaussian Mixture Model](https://en.wikipedia.org/wiki/Mixture_model#Gaussian_mixture_model) in pure JAX and later feed in this dataset to GenJAXMix.
#

# %%
import jax
import matplotlib.pyplot as plt

key = jax.random.key(0)
K_max = 5
N = 1000

# Split the key so each random draw uses an independent subkey
key, key_beta, key_centers = jax.random.split(key, 3)

# Generate cluster proportions via stick-breaking
betas = jax.random.beta(key_beta, 1.0, 1.0, shape=(K_max,))
beta_not = jax.lax.cumprod(1 - betas[:-1])
beta_not = jax.numpy.concatenate([jax.numpy.array([1.0]), beta_not])
weights = betas * beta_not

# Generate cluster centers
cluster_centers = jax.random.normal(key_centers, (K_max, 2))


# %% [markdown]
# We follow [the stick-breaking process](https://en.wikipedia.org/wiki/Dirichlet_process#The_stick-breaking_process) construction for Dirichlet Process Mixture Models. First we generate the stick lengths, which represent the proportion of data points expected to belong to each cluster. We then generate the cluster centers; here we sample the centers from normal distributions.
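#
# Concretely, stick-breaking draws $\beta_k \sim \mathrm{Beta}(1, \alpha)$ (here $\alpha = 1$, so each $\beta_k$ is uniform on $[0, 1]$) and sets
#
# $$w_k = \beta_k \prod_{j=1}^{k-1} (1 - \beta_j),$$
#
# so each weight is the fraction $\beta_k$ of the stick remaining after the first $k-1$ breaks; the `cumprod` above computes exactly this remaining length.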

# %%

# Split off fresh subkeys rather than reusing the original key
key, key_assign, key_noise = jax.random.split(key, 3)

# Sample cluster assignments
cluster_assignments = jax.random.categorical(key_assign, jax.numpy.log(weights), shape=(N,))

# Generate data points: each point is its cluster center plus Gaussian noise (std 0.2)
data_points = cluster_centers[cluster_assignments] + jax.random.normal(key_noise, (N, 2)) / 5

# Plot the clusters
plt.figure(figsize=(8, 6))
plt.scatter(data_points[:, 0], data_points[:, 1], c=cluster_assignments, cmap='viridis', alpha=0.6)
plt.scatter(cluster_centers[:, 0], cluster_centers[:, 1], c='red', marker='x', s=100, label='Cluster Centers')
plt.title('Generated Data Points and Cluster Centers')
plt.xlabel('X-axis')
plt.ylabel('Y-axis')
plt.legend()
plt.show()

# %% [markdown]
# Here we generated 1000 data points drawn from up to five clusters.

# %% [markdown]
# # Model
# We will now define a Dirichlet Process Mixture Model using GenJAXMix to cluster this dataset.
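# The GenJAXMix model definition is still to come in this draft. In the meantime, the cell below sketches in pure JAX the core step such a sampler performs: reassigning each point to a cluster with probability proportional to cluster weight times likelihood. The function `gibbs_assignments` is illustrative only and not part of the GenJAXMix API.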

# %%
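def gibbs_assignments(key, data, centers, weights, noise_std=0.2):
    # Log-likelihood of every point under every cluster center (shape N x K_max)
    sq_dists = jax.numpy.sum((data[:, None, :] - centers[None, :, :]) ** 2, axis=-1)
    log_lik = -0.5 * sq_dists / noise_std**2
    # Combine with log cluster weights and sample one cluster index per point
    logits = jax.numpy.log(weights)[None, :] + log_lik
    return jax.random.categorical(key, logits, axis=-1)


key, key_gibbs = jax.random.split(key)
new_assignments = gibbs_assignments(key_gibbs, data_points, cluster_centers, weights)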
Binary file added docs/tutorials/inferred_clusters_animation.mp4
Binary file not shown.
606 changes: 606 additions & 0 deletions docs/tutorials/sub_cluster.ipynb

Large diffs are not rendered by default.

207 changes: 207 additions & 0 deletions examples/clean.ipynb

Large diffs are not rendered by default.

264 changes: 264 additions & 0 deletions examples/vector.ipynb

Large diffs are not rendered by default.

33 changes: 33 additions & 0 deletions mkdocs.yml
@@ -0,0 +1,33 @@
site_name: GenJAXMix

theme:
  name: "material"
  # features:
  #   - navigation.tabs
  #   - toc.follow
  #   - toc.integrate

# markdown_extensions:
#   - toc

plugins:
  - mkdocs-jupyter:
      execute: true
  - mkdocstrings:
      handlers:
        python:
          options:
            show_source: false


nav:
  - Home: index.md
  - Tutorials:
      - Getting Started: tutorials/getting_started.py
      - DPMMs from the Ground Up: tutorials/dpmm_ground_up.ipynb
      - Splitting Clusters: tutorials/sub_cluster.ipynb
  - Library Reference:
      - API: reference/api.md
      - Conjugacy: reference/conjugacy.md
      - Inference: reference/inference.md
  - For Developers: dev.md
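
With this configuration the site can be built and previewed locally. Assuming `mkdocs-material`, `mkdocs-jupyter`, and `mkdocstrings` are available in the development environment, something like the following should work:

```
uv run mkdocs serve
```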
152 changes: 152 additions & 0 deletions notebooks/split_merge.ipynb
@@ -0,0 +1,152 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Jain Neal Split-Merge Moves"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Random Split-Merge Procedure"
]
},
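{
"cell_type": "markdown",
"metadata": {},
"source": [
"For a proposed split of one cluster of size $M+N$ into clusters of sizes $M$ and $N$, the acceptance probability below combines three log-ratios: the proposal ratio $q(\\mathrm{merge})/q(\\mathrm{split}) = 2^{M+N-2}$, the CRP partition prior ratio $\\alpha\\,\\Gamma(M)\\Gamma(N)/\\Gamma(M+N)$, and the marginal-likelihood ratio (currently stubbed to $1$)."
]
},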
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"import matplotlib.pyplot as plt\n",
"from dataclasses import dataclass"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"K_max = 10\n",
"K_true = 3\n",
"D = 2\n",
"N = 10\n",
"alpha = 1.0\n",
"sigma_hyper = 1.0\n",
"\n",
"key = jax.random.key(0)\n",
"\n",
"key, *subkeys = jax.random.split(key, 5)\n",
"\n",
"mu_true = jax.random.normal(subkeys[0], (K_max, D)) * sigma_hyper\n",
"\n",
"pi_true = jax.random.dirichlet(subkeys[1], jnp.ones(K_true)* alpha)\n",
"pi_true = jnp.concat([pi_true, jnp.zeros(K_max-K_true)], axis=0)\n",
"\n",
"z_true = jax.random.categorical(subkeys[2], jnp.log(pi_true), shape=(N,))\n",
"\n",
"data = jax.random.normal(subkeys[3], (N,D)) + mu_true[z_true]"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1 1 0 0 1 1 1 1 0 1]\n",
"A: 3.465735912322998 B: -4.094344615936279 C: 0.0\n",
"0.53333324\n"
]
},
{
"data": {
"text/plain": [
"Array([1, 1, 0, 0, 1, 1, 1, 1, 0, 1], dtype=int32)"
]
},
"execution_count": 73,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"@dataclass\n",
"class Latent:\n",
" K_max: int\n",
" alpha: float\n",
" K: int\n",
" z: jax.Array\n",
"\n",
"\n",
"def partition_log_ratio(alpha, M: int, N: int):\n",
" return jnp.log(alpha) - jax.scipy.special.gammaln(M+N) + jax.scipy.special.gammaln(M) + jax.scipy.special.gammaln(N)\n",
"\n",
"def proposal_log_ratio(M,N):\n",
" return jnp.log(2)*(M+N-2)\n",
"\n",
"def likelihood_log_ratio(data, ):\n",
" return 0.0\n",
"\n",
"def random_split(key, data, latent: Latent, i: int, j: int):\n",
" c_next = jnp.max(latent.z)+1\n",
" key, subkey = jax.random.split(key)\n",
" splits = jnp.where(jax.random.bernoulli(subkey, shape=(latent.z.shape[0],)), latent.z[i], c_next)\n",
" mask = latent.z == latent.z[i]\n",
" z_proposal = jnp.where(mask, splits, latent.z)\n",
" z_proposal = z_proposal.at[i].set(latent.z[i])\n",
" z_proposal = z_proposal.at[j].set(c_next)\n",
" \n",
" Ni = jnp.count_nonzero(z_proposal == latent.z[i])\n",
" Nj = jnp.count_nonzero(z_proposal == c_next)\n",
"\n",
" A = proposal_log_ratio(Ni,Nj)\n",
" B = partition_log_ratio(alpha, Ni, Nj)\n",
" C = likelihood_log_ratio(1,1)\n",
" print(f\"A: {A} B: {B} C: {C}\")\n",
"\n",
" a = jnp.exp(min(0, A + B + C))\n",
" print(a)\n",
" key, subkey = jax.random.split(key)\n",
" u = jax.random.uniform(subkey)\n",
" z_new = jnp.where(u < a, z_proposal, latent.z)\n",
" return z_new\n",
"\n",
"\n",
"key = jax.random.key(2)\n",
"latent = Latent(K_max, alpha, 2, jax.random.categorical(key, jnp.ones(2), shape=(N,)))\n",
"print(latent.z)\n",
"random_split(key, data, latent, 0, 1)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
11 changes: 11 additions & 0 deletions src/genjaxmix/__init__.py
@@ -0,0 +1,11 @@
"""Run Dirichlet Process Mixture Model inference using GenJAX.

Modules exported by this package:

- `dpmm`
"""
# from .dpmm import generate

# __all__ = [
# "generate"
# ]
26 changes: 15 additions & 11 deletions src/genspn/smc.py → src/genjaxmix/_smc.py
@@ -3,16 +3,12 @@
import numpy as np
from plum import dispatch
from jaxtyping import Array, Integer
from genspn.distributions import (Dirichlet, MixtureModel,
from genjaxmix.distributions import (Dirichlet, MixtureModel,
    posterior, sample, logpdf, Normal, Categorical, Mixed, Cluster, Trace)
from functools import partial

@partial(jax.jit, static_argnames=['gibbs_iters', 'max_clusters', 'n_steps'])
def smc(key, trace, data_test, n_steps, data, gibbs_iters, max_clusters):

    if not isinstance(data, tuple):
        data = (data,)
        data_test = (data_test,)
    smc_keys = jax.random.split(key, n_steps)

    def wrap_step(trace, n):
@@ -45,7 +41,6 @@ def wrap_step(trace, n):
    return trace, sum_logprobs



def rejuvenate(key, data, trace, gibbs_iters, max_clusters):
    extended_pi = jnp.concatenate((trace.cluster.pi, jnp.zeros(max_clusters)))
    log_likelihood_mask = jnp.where(extended_pi == 0, -jnp.inf, 0)
@@ -94,6 +89,16 @@ def get_weights(trace, K, data, q_split_trace, max_clusters):

    return logpdf_pi + logpdf_split_clusters - logpdf_clusters

def score_q_pi(q_pi, max_clusters, alpha):
    q_pi_dist = Dirichlet(alpha=jnp.ones((1, 2)) * alpha/2)
    q_pi_stack = Categorical(
        jnp.vstack((
            jnp.log(q_pi[:max_clusters]),
            jnp.log(q_pi[max_clusters:]),
        ))[None, :])

    return jax.vmap(logpdf, in_axes=(None, -1))(q_pi_dist, q_pi_stack)

def make_pi(pi, k, pi_split, max_clusters):
    pi_k0 = pi[k]
    pi = pi.at[k].set(pi_k0 * pi_split[k])
@@ -153,11 +158,10 @@ def update_f(f0: Categorical, f: Categorical, k: Integer[Array, ""], K: Integer[Array, ""], max_clusters: Integer[Array, ""]):
    return Categorical(logprobs)

@dispatch
# plum is struggling with this signature for some reason, momentarily using a catch all
# def update_f(f0: Mixed, f: Mixed, k: Integer[Array, ""], K: Integer[Array, ""], max_clusters: Integer[Array, ""]):
def update_f(f0: Mixed, f, k, K, max_clusters):
def update_f(f0: Mixed, f: Mixed, k: Integer[Array, ""], K: Integer[Array, ""], max_clusters: Integer[Array, ""]):
    return Mixed(
        dists=tuple([update_f(f0.dists[i], f.dists[i], k, K, max_clusters) for i in range(len(f0.dists))])
        update_f(f0.normal, f.normal, k, K, max_clusters),
        update_f(f0.categorical, f.categorical, k, K, max_clusters)
    )

def update_vector(v0, split_v, k, K, max_clusters):
@@ -224,4 +228,4 @@ def gibbs_pi(max_clusters, key, alpha, c, rejuvenation=False):
    pi = jax.random.dirichlet(key, alpha / 2 + cluster_counts)
    pi_pairs = pi.reshape((2, -1))
    pi_pairs = pi_pairs / jnp.sum(pi_pairs, axis=0)
    return pi_pairs.reshape(-1)
    return pi_pairs.reshape(-1)
38 changes: 38 additions & 0 deletions src/genjaxmix/conjugacy.py
@@ -0,0 +1,38 @@
import jax.numpy as jnp
import jax
from .dpmm import K, L_num

def posterior_normal_inverse_gamma(assignments: jax.Array, x: jax.Array, mu_0: float=0.0, v_0: float=1.0, a_0: float=1.0, b_0: float=1.0):
    """Compute the posterior parameters of a normal-inverse-gamma distribution given the assignments and data.

    See Section 6 of https://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf by Kevin P. Murphy.

    Args:
        assignments: an array of integers in [0, K) representing the cluster assignments
        x: an array of floats representing the data
        mu_0: the prior mean
        v_0: the prior variance scaling of the mean
        a_0: the prior shape
        b_0: the prior scale

    Returns:
        The per-cluster posterior parameters (m_n, v_n, a_n, b_n).
    """
    counts = jnp.bincount(assignments, length=K)
    sum_x = jax.ops.segment_sum(x, assignments, K)
    sum_x_sq = jax.ops.segment_sum(x**2, assignments, K)

    v_n_inv = 1/v_0 + counts
    m = (1/v_0 * mu_0 + sum_x) / v_n_inv
    a = a_0 + counts / 2
    b = b_0 + 0.5 * (sum_x_sq + 1/v_0*mu_0**2 - v_n_inv * m ** 2)
    return m, 1/v_n_inv, a, b

def posterior_dirichlet(assignments: jax.Array, x: jax.Array):
    """Computes the posterior parameters of a Dirichlet distribution for a multinomial likelihood.

    Args:
        assignments: an array of integers in [0, K) representing the cluster assignments
        x: an array of integers in [0, L_num) representing the data

    Returns:
        A (K, L_num) matrix of per-cluster category counts; adding these to the
        prior concentration gives the posterior Dirichlet parameters.
    """
    one_hot_c = jax.nn.one_hot(assignments, K)
    one_hot_y = jax.nn.one_hot(x, L_num)
    frequency_matrix = one_hot_c.T @ one_hot_y
    return frequency_matrix
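
As a quick sanity check of the normal-inverse-gamma update above, here is a minimal self-contained sketch. The cluster count `K` is defined locally for illustration (in the library it is a constant imported from `genjaxmix.dpmm`), and the update is inlined so the snippet runs standalone:

```
import jax
import jax.numpy as jnp

K = 2  # illustrative; the library imports this constant from genjaxmix.dpmm

def posterior_nig(assignments, x, mu_0=0.0, v_0=1.0, a_0=1.0, b_0=1.0):
    # Same normal-inverse-gamma update as posterior_normal_inverse_gamma above
    counts = jnp.bincount(assignments, length=K)
    sum_x = jax.ops.segment_sum(x, assignments, K)
    sum_x_sq = jax.ops.segment_sum(x**2, assignments, K)
    v_n_inv = 1 / v_0 + counts
    m = (mu_0 / v_0 + sum_x) / v_n_inv
    a = a_0 + counts / 2
    b = b_0 + 0.5 * (sum_x_sq + mu_0**2 / v_0 - v_n_inv * m**2)
    return m, 1 / v_n_inv, a, b

key1, key2 = jax.random.split(jax.random.key(0))
x = jnp.concatenate([-3.0 + jax.random.normal(key1, (200,)),
                     3.0 + jax.random.normal(key2, (200,))])
assignments = jnp.concatenate([jnp.zeros(200, dtype=jnp.int32),
                               jnp.ones(200, dtype=jnp.int32)])

m, v, a, b = posterior_nig(assignments, x)
print(m)  # posterior means are pulled toward the per-cluster sample means (~ -3 and 3)
print(v)  # the posterior variance of each mean shrinks like 1 / (1/v_0 + n_k)
```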