Last active
May 29, 2017 21:26
-
-
Save probinso/588297654ca1a23c33c3b4b9a93479e8 to your computer and use it in GitHub Desktop.
k-means
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Statistics

using Distances
using Distributions
"""
    ncol(matrix)

Return the number of columns of `matrix` (its size along dimension 2).
"""
ncol(matrix) = size(matrix, 2)

"""
    nrow(matrix)

Return the number of rows of `matrix` (its size along dimension 1).
"""
nrow(matrix) = size(matrix, 1)
"""
    k_means(M::Matrix, k::Integer, miters::Integer = 100, d_func = euclidean)

Cluster the columns of `M` starting from `k` randomly drawn centroids.

Every row of `M` is standardized (row mean subtracted, then divided by the row standard
deviation), `k` initial centroids are drawn coordinate-wise from a standard normal
distribution, and the centroid-based `k_means(X, μ, miters, d_func, true)` method is run
on the standardized data. The closure it returns is then applied to the original,
unscaled `M`.

# Arguments
- `M`: data matrix; each column is treated as one observation.
- `k`: number of initial centroids; must be at least 1.
- `miters`: iteration limit forwarded to the centroid-based method.
- `d_func`: distance function, called as `d_func(a, b)` on pairs of columns.

# Returns
Whatever the closure returned by `k_means(X, μ, miters, d_func, true)` yields for `M`.

# Throws
- `ArgumentError` if `k < 1`.
"""
function k_means(M::Matrix, k::Integer, miters::Integer = 100, d_func = euclidean)
    k >= 1 || throw(ArgumentError("k must be at least 1, got $k"))

    # Normalization: standardize every row across the observations (columns).
    M_μ = mean(M, dims=2)
    # A constant row has zero spread, and a single column yields NaN; dividing by one
    # in those cases keeps the standardized data finite.
    M_σ = map(s -> s > 0 ? s : one(s), std(M, dims=2))
    X = (M .- M_μ) ./ M_σ

    # Initialization: one standard-normal draw per centroid coordinate.
    D = Normal()
    μ = [rand(D) for _row in 1:size(X, 1), _col in 1:k]

    return k_means(X, μ, miters, d_func, true)(M)
end
"""
    k_means(X::Matrix, μ::Matrix, miters::Integer = 100, d_func = euclidean,
            cont::Bool = false)

Run the k-means iteration on the columns of `X`, starting from the centroids stored in
the columns of `μ`.

Each iteration assigns every column of `X` to the centroid with the smallest `d_func`
distance (ties go to the lowest centroid index) and then replaces each centroid that
received at least one column by the mean of its columns. Centroids that received nothing
are kept unchanged. Centroids stay in their original column position, so class `j` in
the result always refers to column `j` of `μ`. Iteration stops as soon as an update
leaves the centroids unchanged, or after `miters` iterations.

# Arguments
- `X`: data matrix; each column is one observation.
- `μ`: initial centroids, one per column, with the same number of rows as `X`. It is
  not modified.
- `miters`: maximum number of iterations; must be at least 1.
- `d_func`: distance function, called as `d_func(a, b)` with two column vectors.
- `cont`: when `true`, return the grouping closure instead of applying it to `X`.

# Returns
- `cont == false`: a `Dict` mapping each class index `1:size(μ, 2)` to the matrix of
  columns of `X` assigned to that class (a zero-column matrix for unused classes).
- `cont == true`: a function `data -> Dict(...)` that applies the same column assignment
  to any `data` with as many columns as `X`.

# Throws
- `DimensionMismatch` if `X` and `μ` have different numbers of rows.
- `ArgumentError` if `μ` has no columns or `miters < 1`.
"""
function k_means(
    X::Matrix, μ::Matrix, miters::Integer = 100, d_func = euclidean, cont::Bool = false
)
    size(X, 1) == size(μ, 1) || throw(
        DimensionMismatch("X has $(size(X, 1)) rows but μ has $(size(μ, 1))")
    )
    k = size(μ, 2)
    k >= 1 || throw(ArgumentError("μ must contain at least one centroid column"))
    miters >= 1 || throw(ArgumentError("miters must be at least 1, got $miters"))

    npoints = size(X, 2)
    # Broadcasting allocates a fresh floating-point copy, so the caller's μ is untouched.
    centroids = float.(μ)
    labels = Int[]
    for _ in 1:miters
        # Expectation: index of the nearest centroid for every observation.
        labels = [
            argmin([d_func(X[:, i], centroids[:, j]) for j in 1:k]) for i in 1:npoints
        ]

        # Maximization: move each used centroid to the mean of its members, in place,
        # so that column j keeps meaning "class j"; unused centroids are not dropped.
        updated = copy(centroids)
        for class in unique(labels)
            updated[:, class] = vec(mean(X[:, labels .== class], dims=2))
        end

        # Converged once an update step changes nothing.
        updated == centroids && break

        # Update
        centroids = updated
    end

    # Bind the final labels to a single-assignment name so the closure does not capture
    # a variable that is reassigned inside the loop.
    assignment = labels
    f = data -> Dict(class => data[:, assignment .== class] for class in 1:k)
    return cont ? f : f(X)
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment