您的位置:首页 > 其它

机器学习-解决类型数量不平衡

2018-02-01 15:26 197 查看
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import classification_report
from sklearn.svm import SVC

# train_test_split moved to sklearn.model_selection in scikit-learn 0.18 and
# sklearn.cross_validation was removed in 0.20; keep the legacy import only
# as a fallback for very old installations.
try:
    from sklearn.model_selection import train_test_split
except ImportError:
    from sklearn.cross_validation import train_test_split

#加载输入文件中的多变量数据
def load_data(input_file):
    """Load comma-separated numeric samples from a text file.

    Every non-blank line is split on ',' and each field is converted with
    ``float``. The last value on a line is taken as the label; the values
    before it are the features.

    Args:
        input_file: Path of the text file to read.

    Returns:
        A tuple ``(X, y)`` of NumPy arrays: ``X`` holds the feature rows and
        ``y`` the corresponding labels (as floats).

    Raises:
        ValueError: If a non-blank line contains a field that ``float``
            cannot parse.
    """
    X = []
    y = []
    with open(input_file, 'r') as f:
        # Iterate the file object directly instead of materialising every
        # line with readlines().
        for line in f:
            # Blank lines (typically a trailing newline at the end of the
            # file) would otherwise reach float('') and raise ValueError.
            if not line.strip():
                continue
            data = [float(x) for x in line.split(',')]
            X.append(data[:-1])
            y.append(data[-1])
    return np.array(X), np.array(y)
#作图函数
def plot_classifier(classifier, X, y, title='Classifier boundaries', annotate=False):
    """Draw a classifier's predictions over a 2-D grid and overlay the samples.

    A new figure is created; the grid predictions are drawn with
    ``pcolormesh`` and the samples with ``scatter``. The figure is not shown
    here -- the caller is expected to call ``plt.show()``.

    Args:
        classifier: Object exposing ``predict``; it is called once with a
            two-column array of grid coordinates.
        X: 2-D NumPy array; columns 0 and 1 are used as the plot coordinates.
        y: Per-sample values passed to ``scatter`` as the point colors.
        title: Title of the figure.
        annotate: When true, every sample is labelled with its coordinates
            rounded to one decimal place.

    Returns:
        None.
    """
    x_coords = X[:, 0]
    y_coords = X[:, 1]

    # Plot range: the data extent padded by one unit on each side.
    # ndarray.min()/max() replace the builtin min()/max(), which would walk
    # the array element by element in Python.
    x_min, x_max = x_coords.min() - 1.0, x_coords.max() + 1.0
    y_min, y_max = y_coords.min() - 1.0, y_coords.max() + 1.0

    # Grid resolution.
    step_size = 0.01

    # Evaluation grid covering the plot range.
    x_values, y_values = np.meshgrid(np.arange(x_min, x_max, step_size),
                                     np.arange(y_min, y_max, step_size))

    # Predict a label for every grid node, then restore the grid shape.
    mesh_output = classifier.predict(np.c_[x_values.ravel(), y_values.ravel()])
    mesh_output = mesh_output.reshape(x_values.shape)

    plt.figure()
    plt.title(title)
    plt.pcolormesh(x_values, y_values, mesh_output, cmap=plt.cm.gray)
    plt.scatter(x_coords, y_coords, c=y, s=60, edgecolors='blue', linewidth=1, cmap=plt.cm.Paired)
    plt.xlim(x_values.min(), x_values.max())
    plt.ylim(y_values.min(), y_values.max())
    # Hide the axis ticks.
    plt.xticks(())
    plt.yticks(())

    if annotate:
        # Distinct loop names: the original loop reused ``y`` and thereby
        # overwrote the label argument with a single coordinate.
        for point_x, point_y in zip(x_coords, y_coords):
            plt.annotate(
                '(' + str(round(point_x, 1)) + ',' + str(round(point_y, 1)) + ')',
                xy=(point_x, point_y), xytext=(-15, 15),
                textcoords='offset points',
                horizontalalignment='right',
                verticalalignment='bottom',
                bbox=dict(boxstyle='round,pad=0.6', fc='white', alpha=0.8),
                arrowprops=dict(arrowstyle='-', connectionstyle='arc3,rad=0'))
# Load the data set.
input_file = 'data_multivar_imbalance.txt'
X, y = load_data(input_file)

# Split into a training set and a held-out test set (20% test, fixed seed).
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=33)

# Baseline run: linear SVM with no class weighting.
param = {'kernel': 'linear'}
classifier = SVC(**param)
# Fit on the training split only. The original fitted and scored on the full
# data set, which left the split unused and made the report measure training
# fit rather than generalisation.
classifier.fit(X_train, y_train)
plot_classifier(classifier, X_train, y_train, 'Training dataset')
# Original author's observation for this run: no boundary line shows up in
# the plot and class-0 scores 0 in the report.
# NOTE(review): target_names assumes both classes occur in the test split --
# confirm for the actual data file.
print(classification_report(y_test, classifier.predict(X_test), target_names=['class-0', 'class-1']))
plt.show()

# Second run: class_weight='balanced' makes SVC weight each class inversely
# to its frequency in the training labels, so that the imbalance between the
# classes does not dominate the fit.
param = {'kernel': 'linear', 'class_weight': 'balanced'}
classifier = SVC(**param)
classifier.fit(X_train, y_train)
plot_classifier(classifier, X_train, y_train, 'Training dataset')
# NOTE(review): the original comments on this run were copied verbatim from
# the baseline ("no boundary line", "class-0 scores 0"); the purpose of this
# run is that this should no longer hold -- confirm against the output.
print(classification_report(y_test, classifier.predict(X_test), target_names=['class-0', 'class-1']))
plt.show()






内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: