<p>Below is a sample <a href="http://en.wikipedia.org/wiki/Silhouette_%28clustering%29" rel="noreferrer">silhouette</a> implementation in both MATLAB and Python/NumPy (keep in mind that I am more fluent in MATLAB):</p>
<h1>1) MATLAB</h1>
<pre class="lang-matlab prettyprint-override"><code>function s = mySilhouette(X, IDX)
    %# X  : matrix of size N-by-p, data where rows are instances
    %# IDX: vector of size N, cluster index of each instance (starting from 1)
    %# s  : vector of size N, silhouette score value of each instance

    N = size(X,1);            %# number of instances
    K = numel(unique(IDX));   %# number of clusters

    %# compute pairwise (squared Euclidean) distance matrix
    D = squareform( pdist(X,'euclidean').^2 );

    %# indices belonging to each cluster
    kIndices = accumarray(IDX, 1:N, [K 1], @(x){sort(x)});

    %# compute a,b,s for each instance
    %# a(i): average distance from i to all other data within the same cluster
    %# b(i): lowest average distance from i to the data of another single cluster
    a = zeros(N,1);
    b = zeros(N,1);
    for i=1:N
        ind = kIndices{IDX(i)}; ind = ind(ind~=i);
        a(i) = mean( D(i,ind) );
        b(i) = min( cellfun(@(ind) mean(D(i,ind)), kIndices([1:K]~=IDX(i))) );
    end
    s = (b-a) ./ max(a,b);
end
</code></pre>
<p>To emulate the plot from the <a href="http://www.mathworks.com/help/stats/silhouette.html" rel="noreferrer">silhouette</a> function in MATLAB, we group the silhouette values by cluster, sort them in descending order within each cluster, then plot the bars horizontally.
Whereas MATLAB adds <code>NaN</code>s to separate the bars of different clusters, I found it easier to simply color-code the bars:</p>
<pre class="lang-matlab prettyprint-override"><code>%# sample data
load fisheriris
X = meas;
N = size(X,1);

%# cluster and compute silhouette score
K = 3;
[IDX,C] = kmeans(X, K, 'distance','sqEuclidean');
s = mySilhouette(X, IDX);

%# plot
[~,ord] = sortrows([IDX s],[1 -2]);
indices = accumarray(IDX(ord), 1:N, [K 1], @(x){sort(x)});
ytick = cellfun(@(ind) (min(ind)+max(ind))/2, indices);
ytickLabels = num2str((1:K)','%d');           %#'

h = barh(1:N, s(ord),'hist');
set(h, 'EdgeColor','none', 'CData',IDX(ord))
set(gca, 'CLim',[1 K], 'CLimMode','manual')
set(gca, 'YDir','reverse', 'YTick',ytick, 'YTickLabel',ytickLabels)
xlabel('Silhouette Value'), ylabel('Cluster')

%# compare against SILHOUETTE
figure, silhouette(X,IDX)
</code></pre>
<p><img src="https://i.stack.imgur.com/nB4hv.png" alt="mySilhouette"> <img src="https://i.stack.imgur.com/bHCc0.png" alt="silhouette"></p>
<hr>
<h1>2) Python</h1>
<p>And here is what I came up with in Python. Unlike the MATLAB version above, it uses plain (not squared) Euclidean distances, so its silhouette values are not directly comparable with the MATLAB ones:</p>
<pre class="lang-py prettyprint-override"><code>import numpy as np
from scipy.cluster.vq import kmeans2
from scipy.spatial.distance import pdist, squareform
from sklearn import datasets
import matplotlib.pyplot as plt
from matplotlib import cm

def silhouette(X, cIDX):
    """
    Computes the silhouette score for each instance of a clustered dataset,
    which is defined as:
        s(i) = (b(i)-a(i)) / max{a(i),b(i)}
    with:
        -1 &lt;= s(i) &lt;= 1

    Args:
        X    : An N-by-p array of N observations in p dimensions
        cIDX : array of length N containing cluster indices (starting from zero)

    Returns:
        s    : silhouette value of each observation
    """

    N = X.shape[0]              # number of instances
    K = len(np.unique(cIDX))    # number of clusters

    # compute pairwise (Euclidean) distance matrix
    D = squareform(pdist(X))

    # indices belonging to each cluster
    kIndices = [np.flatnonzero(cIDX==k) for k in range(K)]

    # compute a,b,s for each instance
    a = np.zeros(N)
    b = np.zeros(N)
    for i in range(N):
        # instances in same cluster other than instance itself
        a[i] = np.mean( [D[i][ind] for ind in kIndices[cIDX[i]] if ind!=i] )
        # instances in other clusters, one cluster at a time
        b[i] = np.min( [np.mean(D[i][ind])
                        for k,ind in enumerate(kIndices) if cIDX[i]!=k] )
    s = (b-a)/np.maximum(a,b)

    return s

def main():
    # load Iris dataset
    data = datasets.load_iris()
    X = data['data']

    # cluster and compute silhouette score
    K = 3
    C, cIDX = kmeans2(X, K)
    s = silhouette(X, cIDX)

    # plot: group by cluster, descending silhouette value within each cluster
    order = np.lexsort((-s,cIDX))
    indices = [np.flatnonzero(cIDX[order]==k) for k in range(K)]
    ytick = [(np.max(ind)+np.min(ind))/2 for ind in indices]
    ytickLabels = ["%d" % x for x in range(K)]
    cmap = cm.jet( np.linspace(0,1,K) ).tolist()
    clr = [cmap[i] for i in cIDX[order]]

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.barh(range(X.shape[0]), s[order], height=1.0,
            edgecolor='none', color=clr)
    ax.set_ylim(ax.get_ylim()[::-1])
    plt.yticks(ytick, ytickLabels)
    plt.xlabel('Silhouette Value')
    plt.ylabel('Cluster')
    plt.show()

if __name__ == '__main__':
    main()
</code></pre>
<p><img src="https://i.stack.imgur.com/YIhL8.png" alt="python_mySilhouette"></p>
<hr>
<h2>Update:</h2>
<p>As noted by others, scikit-learn has since added its own <a href="http://scikit-learn.org/stable/modules/generated/sklearn.metrics.silhouette_samples.html" rel="noreferrer">silhouette metric</a> <a href="https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/cluster/unsupervised.py" rel="noreferrer">implementation</a>. To use it in the above code, replace the call to the custom-defined <code>silhouette</code> function with:</p>
<pre class="lang-py prettyprint-override"><code>from sklearn.metrics import silhouette_samples

...
#s = silhouette(X, cIDX)
s = silhouette_samples(X, cIDX)  # &lt;-- scikit-learn function
...
</code></pre>
<p>The rest of the code can be used as-is to generate the exact same plot.</p>
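<p>For larger datasets, the per-instance Python loop in <code>silhouette</code> can be replaced by array operations. What follows is only a sketch under my own assumptions. <code>silhouette_vectorized</code> is a hypothetical name, not a NumPy, SciPy or scikit-learn function. Like the Python version above, it uses plain Euclidean distances. It requires at least two distinct labels. It returns 0 for instances that are alone in their cluster, which is the usual convention in the silhouette definition (the loop version returns <code>nan</code> there). The per-cluster distance sums come from a single matrix product with a one-hot membership matrix, so the labels do not need to start at zero or be contiguous:</p>

```python
import numpy as np


def silhouette_vectorized(X, labels):
    """Per-instance silhouette values s(i) = (b(i) - a(i)) / max(a(i), b(i)).

    Hypothetical helper (not a NumPy/SciPy/scikit-learn function). Assumes
    Euclidean distances, at least two distinct labels, and sets s(i) = 0
    for instances that are alone in their cluster.
    """
    X = np.asarray(X, dtype=float)
    labels = np.asarray(labels)
    n = X.shape[0]
    rows = np.arange(n)

    # pairwise Euclidean distance matrix via broadcasting (memory ~ n*n*p)
    diff = X[:, None, :] - X[None, :, :]
    D = np.sqrt((diff ** 2).sum(axis=-1))

    # one-hot membership: M[i, k] = 1.0 if instance i belongs to cluster uniq[k]
    uniq = np.unique(labels)
    M = (labels[:, None] == uniq[None, :]).astype(float)
    sizes = M.sum(axis=0)                 # number of instances per cluster
    sums = D @ M                          # sums[i, k] = sum of distances from i to cluster k

    own = np.searchsorted(uniq, labels)   # column of each instance's own cluster
    own_size = sizes[own]

    # a(i): mean distance to the *other* members of the own cluster
    # (D[i, i] = 0, so dividing the row sum by size - 1 excludes i itself)
    with np.errstate(divide="ignore", invalid="ignore"):
        a = sums[rows, own] / (own_size - 1)

    # b(i): lowest mean distance to any other single cluster
    means = sums / sizes
    means[rows, own] = np.inf
    b = means.min(axis=1)

    # singleton clusters keep s(i) = 0
    s = np.zeros(n)
    ok = own_size > 1
    s[ok] = (b[ok] - a[ok]) / np.maximum(a[ok], b[ok])
    return s


# two well-separated 1-D clusters:
# outer points (10.5 - 1) / 10.5 ~ 0.905, inner points (9.5 - 1) / 9.5 ~ 0.895
X = np.array([[0.0], [1.0], [10.0], [11.0]])
print(silhouette_vectorized(X, [0, 0, 1, 1]))
```

<p>The trade-off is memory: the broadcasting step materializes an N-by-N-by-p array, so for large N you may prefer building <code>D</code> with <code>squareform(pdist(X))</code> as in the loop version, or computing it in chunks.</p>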
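<p>The Python plotting code uses <code>np.lexsort((-s, cIDX))</code> as the NumPy counterpart of MATLAB's <code>sortrows([IDX s],[1 -2])</code>. <code>np.lexsort</code> treats its last key as the primary one. Instances are therefore grouped by cluster label and, within each cluster, sorted by descending silhouette value. The toy values below are made up purely to show the resulting order and where the cluster tick labels end up:</p>

```python
import numpy as np

# Toy (made-up) silhouette values and cluster labels for six instances
s = np.array([0.2, 0.9, 0.5, 0.7, 0.1, 0.8])
cIDX = np.array([1, 0, 1, 0, 0, 1])

# np.lexsort uses the LAST key as the primary sort key: sort by cluster
# label first, then by -s, i.e. descending silhouette value within a cluster
order = np.lexsort((-s, cIDX))
print(order)                      # [1 3 4 5 2 0]

# bar positions of each cluster after reordering, and the tick in the middle
indices = [np.flatnonzero(cIDX[order] == k) for k in range(2)]
ytick = [(np.max(ind) + np.min(ind)) / 2 for ind in indices]
print([float(t) for t in ytick])  # [1.0, 4.0]
```

<p>Each cluster's tick sits halfway along its contiguous block of bars, which is what <code>ytick</code> computes in <code>main()</code>.</p>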