반응형
/*******************************************************************************************************************
-- Title : [Py3.5] K-Means Clustering & Axes3D Plot Example w/ Scikit-Learn
-- Reference : scikit-learn.org
-- Keywords : k-means clustering k-평균 클러스터링 scikit-learn sklearn axes3d iris 군집화
*******************************************************************************************************************/
■ Figures
■ Scripts
# -*- coding: utf-8 -*-
# K-means clustering of the iris dataset, visualized in 3D.
#
# Fits three KMeans configurations to the iris features and draws one 3D
# scatter plot per configuration (figures 1-3), then a plot of the
# ground-truth species labels (figure 4).
import numpy as np
import matplotlib.pyplot as plt
# Importing mplot3d registers the "3d" projection on matplotlib < 3.2;
# newer versions register it automatically.
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
from sklearn.cluster import KMeans
from sklearn import datasets

# Code source: Gaël Varoquaux
# Modified for documentation by Jaques Grobler
# License: BSD 3 clause


def _new_iris_axes(fignum):
    """Create (or reset) figure *fignum* and return a fresh 3D axes on it."""
    fig = plt.figure(fignum, figsize=(7, 6))
    fig.clf()  # clear figure in case this figure number already had content
    # Axes3D(fig, ...) no longer adds itself to the figure on
    # matplotlib >= 3.6 (the plot would be empty), so create the 3D axes
    # through the figure and then place it where the old rect put it.
    ax = fig.add_subplot(111, projection="3d", elev=48, azim=134)
    ax.set_position([0, 0, 0.95, 1])
    return ax


def _scatter_iris(ax, X, colors):
    """Scatter iris samples on *ax*, coloured by *colors*.

    Axes: x = column 3 (petal width), y = column 0 (sepal length),
    z = column 2 (petal length) of *X*.
    """
    ax.scatter(X[:, 3], X[:, 0], X[:, 2], c=colors)
    # Hide the numeric tick labels; only the axis names are shown.
    # (ax.w_xaxis & co. were removed in matplotlib 3.8; xaxis is the same object.)
    ax.xaxis.set_ticklabels([])
    ax.yaxis.set_ticklabels([])
    ax.zaxis.set_ticklabels([])
    ax.set_xlabel('Petal width')
    ax.set_ylabel('Sepal length')
    ax.set_zlabel('Petal length')


def main():
    """Load iris, fit the KMeans variants, and show the 3D plots."""
    # ------------------------------
    # -- Get iris data
    # ------------------------------
    # The estimators below have no random_state, so they draw from NumPy's
    # global RNG; seeding it keeps the clusterings reproducible.
    np.random.seed(5)

    iris = datasets.load_iris()
    print(iris['feature_names'])
    print("... feature_names", "." * 100, "\n")

    X = iris.data
    y = iris.target
    print("** X(iris.data) \n", X[:10])
    print("** y(iris.target) \n", y)
    print("... X, y", "." * 100, "\n")

    # ------------------------------
    # -- Initiate
    # ------------------------------
    # n_init is pinned explicitly: scikit-learn 1.4 changed its default from
    # 10 to 'auto' (a single run for k-means++ init), which would silently
    # change the results of the first two estimators.
    dic_estimators = {
        'k_means_iris_3': KMeans(n_clusters=3, n_init=10),
        'k_means_iris_8': KMeans(n_clusters=8, n_init=10),
        'k_means_iris_bad_init': KMeans(n_clusters=3, n_init=1, init='random'),
    }
    print(dic_estimators)
    print(",,, dic_estimators", "," * 100, "\n")

    # ------------------------------
    # -- Draw plot
    # ------------------------------
    fignum = 1

    # -- Figure 1,2,3: cluster assignment of each estimator
    for name, est in dic_estimators.items():
        ax = _new_iris_axes(fignum)
        est.fit(X)
        # np.float was removed in NumPy 1.24; the builtin float is equivalent.
        _scatter_iris(ax, X, est.labels_.astype(float))
        ax.set_title(name)  # otherwise the three figures are indistinguishable
        fignum += 1
    plt.show()

    # -- Figure 4: ground truth
    ax = _new_iris_axes(fignum)
    for species, label in [('Setosa', 0), ('Versicolour', 1), ('Virginica', 2)]:
        members = y == label
        ax.text3D(X[members, 3].mean(),
                  X[members, 0].mean() + 1.5,
                  X[members, 2].mean(),
                  species,
                  horizontalalignment='center',
                  bbox=dict(alpha=.5, edgecolor='w', facecolor='w'))
    # Reorder the labels to have colors matching the cluster results.
    # NOTE(review): KMeans cluster ids are arbitrary; the [1, 2, 0] mapping
    # was chosen for one particular run and may not line up on other
    # library versions.
    y_colors = np.choose(y, [1, 2, 0]).astype(float)
    _scatter_iris(ax, X, y_colors)
    ax.set_title('ground truth')
    plt.show()


if __name__ == "__main__":
    main()
반응형