"""Mean-shift clustering on synthetic 2-D Gaussian blobs with PyTorch.

Notebook-style walkthrough:
  * generate clustered data;
  * define distance kernels;
  * run a naive per-point mean-shift loop and animate it;
  * vectorise it into a batched implementation that uses the GPU when available.

Cells are delimited with ``# %%`` so the file can also be run interactively
(Jupyter / VS Code). A bare expression at the end of a cell is there only so
it gets displayed.
"""
# %%
import math
import operator
import time
import timeit
from functools import partial

import matplotlib.pyplot as plt
import torch
from IPython.display import HTML
from matplotlib.animation import FuncAnimation
from torch import tensor
from torch.distributions.multivariate_normal import MultivariateNormal

torch.manual_seed(42)
torch.set_printoptions(precision=3, linewidth=140, sci_mode=False)

# %% Synthetic data: n_clusters blobs of n_samples points each.
n_clusters = 6
n_samples = 250
# Cluster centres are drawn uniformly from [-35, 35) in each coordinate.
centroids = torch.rand(n_clusters, 2) * 70 - 35


def sample(m):
    """Draw n_samples points from a Gaussian centred at m with covariance diag(5, 5)."""
    return MultivariateNormal(m, torch.diag(tensor([5., 5.]))).sample((n_samples,))


slices = [sample(c) for c in centroids]
# Rows are grouped by cluster: n_samples consecutive rows per centroid, in order.
data = torch.cat(slices)
data.shape


# %%
def plot_data(centroids, data, n_samples, ax=None):
    """Scatter each cluster in its own colour and mark its centroid with an 'x'.

    Assumes ``data`` is laid out as consecutive groups of ``n_samples`` rows,
    one group per entry of ``centroids`` (as produced by ``torch.cat(slices)``).
    A new figure is created when ``ax`` is None.
    """
    if ax is None:
        _, ax = plt.subplots()
    for i, centroid in enumerate(centroids):
        samples = data[i * n_samples:(i + 1) * n_samples]
        ax.scatter(samples[:, 0], samples[:, 1], s=1)
        # A thick black cross under a thin magenta one keeps the marker
        # visible whatever colour the scatter behind it is.
        ax.plot(*centroid, markersize=10, marker="x", color='k', mew=5)
        ax.plot(*centroid, markersize=5, marker="x", color='m', mew=2)


plot_data(centroids, data, n_samples)

# %% The overall mean of all points.
midp = data.mean(0)
midp

# %%
# One midpoint marker per cluster group (previously hard-coded to 6).
plot_data([midp] * n_clusters, data, n_samples)


# %% Kernels: map a distance to a weight.
def gaussian(d, bw):
    """Gaussian kernel with bandwidth (standard deviation) ``bw`` at distance ``d``.

    The 1/(bw*sqrt(2*pi)) normalisation has no effect on mean shift, because
    the weights are always divided by their sum.
    """
    return torch.exp(-0.5 * (d / bw) ** 2) / (bw * math.sqrt(2 * math.pi))


def plot_func(f):
    """Plot ``f`` over distances in [0, 10]."""
    x = torch.linspace(0, 10, 100)
    plt.plot(x, f(x))


plot_func(partial(gaussian, bw=2.5))
partial


# %%
def tri(d, i):
    """Triangular kernel: weight 1 at d=0, falling linearly to 0 at d=i, 0 beyond."""
    return (-d + i).clamp_min(0) / i


plot_func(partial(tri, i=8))

# %% Weights for a single point x against every point in X.
X = data.clone()
x = data[0]
x

# %%
x.shape, X.shape, x[None].shape

# %% Broadcasting: both x[None] (1, 2) and x (2,) broadcast against X (N, 2).
(x[None] - X)[:8]

# %%
(x - X)[:8]

# %%
# Exercise: rewrite this distance computation using torch.einsum.
dist = ((x - X) ** 2).sum(1).sqrt()
dist[:8]

# %%
weight = gaussian(dist, 2.5)
weight

# %%
weight.shape, X.shape

# %%
weight[:, None].shape

# %%
weight[:, None] * X


# %% Naive mean shift: move each point to the kernel-weighted mean of all points.
def one_update(X, kernel=partial(tri, i=8)):
    """Perform one mean-shift pass over ``X``, modifying it in place.

    Points are updated one at a time while ``X`` is being modified. Later
    points are therefore shifted using the already-updated positions of
    earlier ones.

    Args:
        X: 2-D tensor of points, one per row; overwritten in place.
        kernel: maps a tensor of distances to weights. The default is the
            triangular kernel with i=8 (the original hard-coded choice);
            ``partial(gaussian, bw=2.5)`` is the alternative.
    """
    for i, x in enumerate(X):
        dist = torch.sqrt(((x - X) ** 2).sum(1))
        weight = kernel(dist)
        X[i] = (weight[:, None] * X).sum(0) / weight.sum()


def meanshift(data, n_iter=5):
    """Return a copy of ``data`` after ``n_iter`` naive mean-shift passes."""
    X = data.clone()
    for _ in range(n_iter):
        one_update(X)
    return X


# Plain-Python replacement for IPython's `%time` (the magic is a SyntaxError
# outside IPython).
start = time.perf_counter()
X = meanshift(data)
print(f"naive meanshift: {time.perf_counter() - start:.2f} s")
# NOTE(review): the centroid markers are offset by +2, presumably so they do
# not hide the points that have collapsed onto the cluster centres.
plot_data(centroids + 2, X, n_samples)


# %% Animate the naive algorithm.
def do_one(d):
    """FuncAnimation callback for frame ``d``.

    Frame 0 draws the current state as-is. Every later frame first applies
    one mean-shift pass to the module-level ``X``, then redraws it on the
    module-level ``ax``.
    """
    if d:
        one_update(X)
    ax.clear()
    plot_data(centroids + 2, X, n_samples, ax=ax)


# Exercise: create your own animation.
X = data.clone()
fig, ax = plt.subplots()
ani = FuncAnimation(fig, do_one, frames=5, interval=500, repeat=False)
# Presumably closes the figure so the static plot is not shown next to the
# animation.
plt.close()
HTML(ani.to_jshtml())

# %% Batched version: compute the weights for bs points at once.
bs = 5
X = data.clone()
x = X[:bs]
x.shape, X.shape


# %%
def dist_b(a, b):
    """Pairwise Euclidean distances between rows of two 2-D tensors.

    Returns a tensor of shape (len(b), len(a)) with result[i, j] = ||a[j] - b[i]||.
    """
    return (((a[None] - b[:, None]) ** 2).sum(2)).sqrt()


dist_b(X, x)

# %%
dist_b(X, x).shape

# %%
X[None, :].shape, x[:, None].shape, (X[None, :] - x[:, None]).shape

# %%
weight = gaussian(dist_b(X, x), 2)
weight

# %%
weight.shape, X.shape

# %%
weight[..., None].shape, X[None].shape

# %% Weighted sum of all points for each of the bs query points.
num = (weight[..., None] * X[None]).sum(1)
num.shape

# %%
num

# %% The same numerator written as a matrix product.
torch.einsum('ij,jk->ik', weight, X)

# %%
weight @ X

# %%
div = weight.sum(1, keepdim=True)
div.shape

# %%
num / div


# %%
def meanshift(data, bs=500, n_iter=5, kernel=partial(gaussian, bw=2.5)):
    """Return a copy of ``data`` after ``n_iter`` batched mean-shift passes.

    Rows are processed in batches of ``bs``. Each batch is moved to the
    kernel-weighted mean of all points. As in the naive version, the working
    copy is updated in place, so later batches see earlier batches' new
    positions.

    Args:
        data: 2-D tensor of points, one per row (CPU or GPU).
        bs: batch size. Each step builds a (batch, len(data)) weight matrix,
            so this trades memory for speed.
        n_iter: number of passes (previously fixed at 5).
        kernel: maps distances to weights. The default is the Gaussian with
            bw=2.5; ``partial(tri, i=8)`` is the alternative.
    """
    n = len(data)
    X = data.clone()
    for _ in range(n_iter):
        for i in range(0, n, bs):
            s = slice(i, min(i + bs, n))
            weight = kernel(dist_b(X, X[s]))   # (batch, n)
            div = weight.sum(1, keepdim=True)  # (batch, 1)
            X[s] = weight @ X / div
    return X


# Use the GPU when one is available. The previous unconditional `.cuda()`
# raised on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
data = data.to(device)
X = meanshift(data).cpu()

# Plain-Python replacement for IPython's `%timeit -n 5`: 7 repeats of 5 loops,
# reporting the best per-loop time. The `.cpu()` copy makes any device work
# finish before each measurement ends.
times = timeit.repeat(lambda: meanshift(data, 1250).cpu(), number=5, repeat=7)
print(f"batched meanshift (bs=1250): {min(times) / 5 * 1000:.1f} ms per loop (best of 7)")
plot_data(centroids + 2, X, n_samples)