Skip to content

Commit 1fe70c0

Browse files
Circle CICircle CI
authored andcommitted
CircleCI update of dev docs (3297).
1 parent 0ec72bf commit 1fe70c0

294 files changed

Lines changed: 104304 additions & 100863 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
Binary file not shown.
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# -*- coding: utf-8 -*-
2+
r"""
3+
=====================================
4+
Gaussian Mixture Model OT Barycenters
5+
=====================================
6+
7+
This example illustrates the computation of a barycenter between Gaussian
8+
Mixtures in the sense of GMM-OT [69]. This computation is done using the
9+
fixed-point method for OT barycenters with generic costs [77], for which POT
10+
provides a general solver, and a specific GMM solver. Note that this is a
11+
'free-support' method, implying that the number of components of the barycenter
12+
GMM and their weights are fixed.
13+
14+
The idea behind GMM-OT barycenters is to see the GMMs as discrete measures over
15+
the space of Gaussian distributions :math:`\mathcal{N}` (or equivalently the
16+
Bures-Wasserstein manifold), and to compute barycenters with respect to the
17+
2-Wasserstein distance between measures in :math:`\mathcal{P}(\mathcal{N})`: a
18+
gaussian mixture is a finite combination of Diracs on specific gaussians, and
19+
two mixtures are compared with the 2-Wasserstein distance on this space, where
20+
ground cost the squared Bures distance between gaussians.
21+
22+
[69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space
23+
of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970.
24+
25+
[77] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing
26+
Barycentres of Measures for Generic Transport Costs. arXiv preprint 2501.04016
27+
(2024)
28+
29+
"""
30+
31+
# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr>
32+
#
33+
# License: MIT License
34+
35+
# sphinx_gallery_thumbnail_number = 1
36+
37+
# %%
38+
# Generate data
39+
import numpy as np
40+
import matplotlib.pyplot as plt
41+
from matplotlib.patches import Ellipse
42+
import ot
43+
from ot.gmm import gmm_barycenter_fixed_point
44+
45+
46+
K = 3 # number of GMMs
47+
d = 2 # dimension
48+
n = 6 # number of components of the desired barycenter
49+
50+
51+
def get_random_gmm(K, d, seed=0, min_cov_eig=1, cov_scale=1e-2):
52+
rng = np.random.RandomState(seed=seed)
53+
means = rng.randn(K, d)
54+
P = rng.randn(K, d, d) * cov_scale
55+
# C[k] = P[k] @ P[k]^T + min_cov_eig * I
56+
covariances = np.einsum("kab,kcb->kac", P, P)
57+
covariances += min_cov_eig * np.array([np.eye(d) for _ in range(K)])
58+
weights = rng.random(K)
59+
weights /= np.sum(weights)
60+
return means, covariances, weights
61+
62+
63+
m_list = [5, 6, 7] # number of components in each GMM
64+
offsets = [np.array([-3, 0]), np.array([2, 0]), np.array([0, 4])]
65+
means_list = [] # list of means for each GMM
66+
covs_list = [] # list of covariances for each GMM
67+
w_list = [] # list of weights for each GMM
68+
69+
# generate GMMs
70+
for k in range(K):
71+
means, covs, b = get_random_gmm(
72+
m_list[k], d, seed=k, min_cov_eig=0.25, cov_scale=0.5
73+
)
74+
means = means / 2 + offsets[k][None, :]
75+
means_list.append(means)
76+
covs_list.append(covs)
77+
w_list.append(b)
78+
79+
# %%
80+
# Compute the barycenter using the fixed-point method
81+
init_means, init_covs, _ = get_random_gmm(n, d, seed=0)
82+
weights = ot.unif(K) # barycenter coefficients
83+
means_bar, covs_bar, log = gmm_barycenter_fixed_point(
84+
means_list,
85+
covs_list,
86+
w_list,
87+
init_means,
88+
init_covs,
89+
weights,
90+
iterations=3,
91+
log=True,
92+
)
93+
94+
95+
# %%
96+
# Define plotting functions
97+
98+
99+
# draw a covariance ellipse
100+
def draw_cov(mu, C, color=None, label=None, nstd=1, alpha=0.5, ax=None):
101+
def eigsorted(cov):
102+
vals, vecs = np.linalg.eigh(cov)
103+
order = vals.argsort()[::-1].copy()
104+
return vals[order], vecs[:, order]
105+
106+
vals, vecs = eigsorted(C)
107+
theta = np.degrees(np.arctan2(*vecs[:, 0][::-1]))
108+
w, h = 2 * nstd * np.sqrt(vals)
109+
ell = Ellipse(
110+
xy=(mu[0], mu[1]),
111+
width=w,
112+
height=h,
113+
alpha=alpha,
114+
angle=theta,
115+
facecolor=color,
116+
edgecolor=color,
117+
label=label,
118+
fill=True,
119+
)
120+
if ax is None:
121+
ax = plt.gca()
122+
ax.add_artist(ell)
123+
124+
125+
# draw a gmm as a set of ellipses with weights shown in alpha value
126+
def draw_gmm(ms, Cs, ws, color=None, nstd=0.5, alpha=1, label=None, ax=None):
127+
for k in range(ms.shape[0]):
128+
draw_cov(
129+
ms[k], Cs[k], color, label if k == 0 else None, nstd, alpha * ws[k], ax=ax
130+
)
131+
132+
133+
# %%
134+
# Plot the results
135+
c_list = ["#7ED321", "#4A90E2", "#9013FE", "#F5A623"]
136+
c_bar = "#D0021B"
137+
fig, ax = plt.subplots(figsize=(6, 6))
138+
axis = [-4, 4, -2, 6]
139+
ax.set_title("Fixed Point Barycenter (3 Iterations)", fontsize=16)
140+
for k in range(K):
141+
draw_gmm(means_list[k], covs_list[k], w_list[k], color=c_list[k], ax=ax)
142+
draw_gmm(means_bar, covs_bar, ot.unif(n), color=c_bar, ax=ax)
143+
ax.axis(axis)
144+
ax.axis("off")
145+
146+
# %%

master/_downloads/0dbd57c6090215001a0a712021c577e5/plot_GMMOT_plan.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
},
1616
"outputs": [],
1717
"source": [
18-
"# Author: Eloi Tanguy <eloi.tanguy@u-paris>\n# Remi Flamary <remi.flamary@polytehnique.edu>\n# Julie Delon <julie.delon@math.cnrs.fr>\n#\n# License: MIT License\n\n# sphinx_gallery_thumbnail_number = 1\n\nimport numpy as np\nfrom ot.plot import plot1D_mat, rescale_for_imshow_plot\nfrom ot.gmm import gmm_ot_plan_density, gmm_pdf, gmm_ot_apply_map\nimport matplotlib.pyplot as plt"
18+
"# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr>\n# Remi Flamary <remi.flamary@polytehnique.edu>\n# Julie Delon <julie.delon@math.cnrs.fr>\n#\n# License: MIT License\n\n# sphinx_gallery_thumbnail_number = 1\n\nimport numpy as np\nfrom ot.plot import plot1D_mat, rescale_for_imshow_plot\nfrom ot.gmm import gmm_ot_plan_density, gmm_pdf, gmm_ot_apply_map\nimport matplotlib.pyplot as plt"
1919
]
2020
},
2121
{
Binary file not shown.

master/_downloads/12cf635d7b9aa9f87c0e3bdc36aaa712/plot_SSNB.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
2017.
4242
"""
4343

44-
# Author: Eloi Tanguy <eloi.tanguy@u-paris.fr>
44+
# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr>
4545
# License: MIT License
4646

4747
# sphinx_gallery_thumbnail_number = 3
Binary file not shown.

master/_downloads/15645a78701cc4e31af4898794deb04d/plot_generalized_free_support_barycenter.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
},
1616
"outputs": [],
1717
"source": [
18-
"# Author: Eloi Tanguy <eloi.tanguy@polytechnique.edu>\n#\n# License: MIT License\n\n# sphinx_gallery_thumbnail_number = 2\n\nimport numpy as np\nimport matplotlib.pyplot as plt\nimport matplotlib.pylab as pl\nimport ot\nimport matplotlib.animation as animation"
18+
"# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr>\n#\n# License: MIT License\n\n# sphinx_gallery_thumbnail_number = 2\n\nimport numpy as np\nimport matplotlib.pyplot as plt\nimport matplotlib.pylab as pl\nimport ot\nimport matplotlib.animation as animation"
1919
]
2020
},
2121
{
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)