机器学习-解决类型数量不平衡
2018-02-01 15:26
197 查看
from sklearn.svm import SVC from sklearn.metrics import classification_report from sklearn.cross_validation import train_test_split import numpy as np import matplotlib.pyplot as plt #加载输入文件中的多变量数据 def load_data(input_file): X = [] y = [] with open(input_file, 'r') as f: for line in f.readlines(): data = [float(x) for x in line.split(',')] X.append(data[:-1]) y.append(data[-1]) X = np.array(X) y = np.array(y) return X, y #作图函数 def plot_classifier(classifier, X, y, title='Classifier boundaries', annotate=False): # 定义绘图范围 x_min, x_max = min(X[:, 0]) - 1.0, max(X[:, 0]) + 1.0 y_min, y_max = min(X[:, 1]) - 1.0, max(X[:, 1]) + 1.0 # 定义步数大小 step_size = 0.01 # 定义计算网格 x_values, y_values = np.meshgrid(np.arange(x_min, x_max, step_size), np.arange(y_min, y_max, step_size)) # 计算分类器输出 mesh_output = classifier.predict(np.c_[x_values.ravel(), y_values.ravel()]) # 重构数组形状 mesh_output = mesh_output.reshape(x_values.shape) plt.figure() plt.title(title) plt.pcolormesh(x_values, y_values, mesh_output, cmap=plt.cm.gray) plt.scatter(X[:, 0], X[:, 1], c=y, s=60, edgecolors='blue', linewidth=1, cmap=plt.cm.Paired) plt.xlim(x_values.min(), x_values.max()) plt.ylim(y_values.min(), y_values.max()) plt.xticks(()) plt.yticks(()) if annotate: for x, y in zip(X[:, 0], X[:, 1]): plt.annotate( '(' + str(round(x, 1)) + ',' + str(round(y, 1)) + ')', xy = (x, y), xytext = (-15, 15), textcoords = 'offset points', horizontalalignment = 'right', verticalalignment = 'bottom', bbox = dict(boxstyle = 'round,pad=0.6', fc = 'white', alpha = 0.8), arrowprops = dict(arrowstyle = '-', connectionstyle = 'arc3,rad=0')) return None #加载数据 input_file = 'data_multivar_imbalance.txt' X, y = load_data(input_file) #划分数据集,训练集 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=33) param = {'kernel':'linear'} classifier = SVC(**param) classifier.fit(X, y) plot_classifier(classifier, X_train, y_train, 'Training dataset') #没有边界线 print(classification_report(y, classifier.predict(X), target_names=['class-0', 'class-1']))#class-0的准确性为0 plt.show() #参数class_weight的作用是统计不同类型数据点的数量,调整权重,让类型不平衡问题不影响分类效果 param = {'kernel':'linear', 'class_weight':'balanced'} classifier = SVC(**param) classifier.fit(X, y) plot_classifier(classifier, X_train, y_train, 'Training dataset') #没有边界线 print(classification_report(y, classifier.predict(X), target_names=['class-0', 'class-1']))#class-0的准确性为0 plt.show()
相关文章推荐
- 机器学习可以解决哪些类型的任务?
- 如何解决机器学习中的数据不平衡问题
- 如何解决机器学习中数据不平衡问题
- 如何解决机器学习中的数据不平衡问题
- 如何解决机器学习中数据不平衡问题
- 如何解决机器学习中数据不平衡问题
- 如何解决机器学习中数据不平衡问题
- 如何解决机器学习中数据不平衡问题
- 如何解决机器学习中的数据不平衡问题
- 解决真实世界问题:如何在不平衡类上使用机器学习?
- 如何解决机器学习中数据不平衡问题
- 如何解决机器学习中数据不平衡问题
- 深度 | 解决真实世界问题:如何在不平衡类上使用机器学习?
- 如何解决机器学习中数据不平衡问题
- 机器学习模型构建时正负样本不平衡带来的问题及解决方法
- 机器学习可以解决哪些类型的任务?
- 深度 | 解决真实世界问题:如何在不平衡类上使用机器学习?
- 使用mysql innodb 使用5.7的json类型遇到的坑和解决办法
- 定位点类型和递归部分的类型不匹配,如何解决?
- 解决txt文件中数据带有日期类型导入oracle中到不进去的问题