/*******************************************************************************************************************
-- Title : [Py3.5] K-Means Clustering & Axis3D plot Example w/ Scitkit-Learn
-- Reference : scikit-learn.org
-- Key word : k-means clustering k-평균 클러스터링 scitkit-learn sklearn axis3d iris 군집화 
*******************************************************************************************************************/

■ Figures



■ Scripts

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# -*- coding: utf-8 -*-
 
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.cluster import KMeans
from sklearn import datasets
 
# Code source: Gaël Varoquaux
# Modified for documentation by Jaques Grobler
# License: BSD 3 clause
 
# ------------------------------
# -- Get iris data
# ------------------------------
np.random.seed(5)
 
centers = [[11], [-1-1], [1-1]]
iris = datasets.load_iris()
 
print(iris['feature_names'])
print("... feature_names""." * 100"\n")
 
= iris.data
= iris.target
 
print("** X(iris.data) \n", X[:10])
print("** y(iris.target) \n", y)
print("... X, y""." * 100"\n")
 
 
# ------------------------------
# -- Initiate
# ------------------------------
dic_estimators = {'k_means_iris_3': KMeans(n_clusters=3),
                  'k_means_iris_8': KMeans(n_clusters=8),
                  'k_means_iris_bad_init': KMeans(n_clusters=3, n_init=1, init='random')
                 }
 
print(dic_estimators)
print(",,, dic_estimators""," * 100"\n")
 
 
# ------------------------------
# -- Draw plot
# ------------------------------
fignum = 1
 
# --
# -- Figure 1,2,3
# --
for name, est in dic_estimators.items():
    fig = plt.figure(fignum, figsize=(7,6))
    plt.clf()  # clear figure
    ax = Axes3D(fig, rect=[00, .951], elev=48, azim=134)
 
    plt.cla()  # clear axis
    est.fit(X)
    labels = est.labels_
 
    ax.scatter(X[:, 3], X[:, 0], X[:, 2], c=labels.astype(np.float))
 
    ax.w_xaxis.set_ticklabels([])
    ax.w_yaxis.set_ticklabels([])
    ax.w_zaxis.set_ticklabels([])
    ax.set_xlabel('Petal width')
    ax.set_ylabel('Sepal length')
    ax.set_zlabel('Petal length')
    fignum = fignum + 1
 
    plt.show()
 
# --
# -- Figure 4
# --
# Plot the ground truth
fig = plt.figure(fignum, figsize=(7,6))
plt.clf()
ax = Axes3D(fig, rect=[00, .951], elev=48, azim=134)
 
plt.cla()
 
for name, label in [('Setosa'0),
                    ('Versicolour'1),
                    ('Virginica'2)]:
    ax.text3D(X[y == label, 3].mean(),
              X[y == label, 0].mean() + 1.5,
              X[y == label, 2].mean(), name,
              horizontalalignment='center',
              bbox=dict(alpha=.5, edgecolor='w', facecolor='w'))
 
# Reorder the labels to have colors matching the cluster results
= np.choose(y, [120]).astype(np.float)
ax.scatter(X[:, 3], X[:, 0], X[:, 2], c=y)
 
ax.w_xaxis.set_ticklabels([])
ax.w_yaxis.set_ticklabels([])
ax.w_zaxis.set_ticklabels([])
ax.set_xlabel('Petal width')
ax.set_ylabel('Sepal length')
ax.set_zlabel('Petal length')
 
plt.show()
cs



저작자 표시 비영리 변경 금지
신고

+ Recent posts