本文共 10867 字,大约阅读时间需要 36 分钟。
上一章的最后我们提到:
训练集上表现得不好,那么测试集上的效果肯定也不好,即$color{red}{欠拟合}$ 训练集上表现得好,测试集表现得不好,即$color{red}{过拟合}$1、增加数据集中的观测值 (样本数量)
2、增加数据的维度 (特征数量)如果拿不到新数据的情况下(样本数量不变),只能通过增加数据维度的方法来解决欠拟合的问题,那用什么方式来增加维度呢?
假设现在有两个特征:x1、x2
我们做二阶多项式扩展,最终的结果是:x1、x2、x12、x1x2 、x22;我们人为得在数据中增加了这些特征:x12、x1x2 、x22;比如新特征:x1x2(长*宽)=面积。思考:如果包含三个特征值,x1、x2、x3,他们的二阶多项式扩展会以怎样的形式展开?
x1x2、x1x3、x2x3、x12、x22、x32;二阶多项式展开,特征之间的最高次方为2,如果做三阶多项式展开,特征之间的最高次方为3,但也包含也之前二阶多项式展开中的所有项。
假设现在有两个特征:x1、x2x1、x2、x12、x1x2 、x22;
x12x2、x1x22、x13、x23
$color{red}{维度爆炸:}$
我们发现,仅仅两个特征做三阶多项式扩展,就会产生9个新的维度。可以预测到的是:随着最初特征数量的增加,多项式扩展增加的维度的上升速度会很快。如果一开始特征就非常多的情况下,进行多项式扩展的阶数又过高的情况下,会导致数据集的增长维度过快,计算机的运算量会非常巨大。我们称为维度爆炸,或维度灾难(cures of dimension)在Sklearn当中有三大模型:Transformer 转换器、Estimator 估计器、Pipeline 管道
## 数据标准化## StandardScaler 画图纸ss = StandardScaler() ## fit_transform训练并转换 ## fit在计算,transform完成输出X_train = ss.fit_transform(X_train) X_train
Transformer有输入有输出,同时输出可以放入Transformer或者Estimator 当中作为输入。
## 模型训练lr = LinearRegression()## LinearRegression 是一个有监督的算法,所以要把特征值和目标值一起放入lr.fit(X_train,Y_train) #训练模型## 模型校验y_predict = lr.predict(X_test) #预测结果
y_predict 是估计器的输出模型,估计器输出无法再放入Transformer 或 Estimator当中再获取另一个输出了。
将Transformer、Estimator 组合起来成为一个大模型。
管道: 输入→□→□→□→■→ 输出□:Transformer ; ■:Estimator ;Transformer放在管道前几个模型中,而Estimator 只能放到管道的最后一个模型中。结合:《》
头文件引入Pipeline:
from sklearn.pipeline import Pipeline
其他需要引入的包:
##家庭用电预测:线性回归算法(时间与功率&功率与电流之间的关系)## 一般用到sklearn的子库import sklearnfrom sklearn.model_selection import train_test_split #训练集测试集划分,最新版本中该库直接归到了sklearn的子库from sklearn.linear_model import LinearRegression # 线性模型from sklearn.preprocessing import StandardScaler # 预处理的库from sklearn.preprocessing import MinMaxScaler## 管道相关的包from sklearn.pipeline import Pipelinefrom sklearn.preprocessing import PolynomialFeaturesfrom sklearn.model_selection import GridSearchCV## 再提一下标准化的概念:## StandardScaler作用:去均值和方差归一化## 假如对于某个特征中的一列数据集,x1,x2, ... ,xn## 标准化后的数据: (x1-均值)/标准差,(x2-均值)/标准差, ... ,(xn-均值)/标准差import numpy as npimport pandas as pdimport matplotlib as mplimport matplotlib.pyplot as pltimport time## 设置字符集,防止中文乱码mpl.rcParams['font.sans-serif'] = ['simHei']mpl.rcParams['axes.unicode_minus'] = False
创建一个关于时间的格式化字符串函数
时间 16/12/2006 用 %d/%m/%Y %H:%M:%S 格式化的方法处理数据def data_fromat(dt): t = time.strptime(' '.join(dt),'%d/%m/%Y %H:%M:%S') return (t.tm_year,t.tm_mon,t.tm_mday,t.tm_hour,t.tm_min,t.tm_sec)
Pipeline的参数是一个列表,列表中存放着每一个模型的信息。
第0个模型名字:ss,告诉系统我要做数据标准化
第1个模型名字:Poly,告诉系统我要做一个多项式扩展。
PolynomialFeatures即进行了ss= StandardScaler()的操作,并做了3阶的扩展第2个模型名字:Linear,告诉系统进行模型训练。fit_intercept=False 表示截距为0
截距:y=ax+b, b是截距。一般推荐使用fit_intercept=True。如果输入特征包含x1,x2,将特征放入多项式扩展的图纸后,我们会得到一个针对x1,x2扩展的特征集,并把数据输出出来。因此在多项式扩展的算法中,存储的特征集合将是扩展后的结果。
## 时间和电压之间的关系 (Linear - 多项式)models = [ Pipeline([ ('ss',StandardScaler()), ('Poly',PolynomialFeatures(degree=3)),#给定多项式扩展操作-3阶扩展 ('Linear',LinearRegression(fit_intercept=False)) ])]model = models[0]
获取数据,在上一章中有详细介绍,本章不再赘述。
## 数据文件的路径path1 = 'C:\\Users\\Gorde\\Desktop\\household_power_consumption\\household_power_consumption_100.txt'#如果没有混合类型的数据时,可以通过low_memory=False来调用更多的内存,提高读取速度df = pd.read_csv(path1,sep=';',low_memory=False) ## 处理异常数据new_df = df.replace('?',np.nan) #替换非法字符为nan空## how='any' 遇到空值就删掉; axis=0 删除行;data = new_df.dropna(axis=0,how='any') ## 日期、时间、有功功率、无功功率、电压、电流、厨房用电功率、洗衣服用电功率、热水器用电功率names2=df.columnsnames=['Date', 'Time', 'Global_active_power', 'Global_reactive_power', 'Voltage', 'Global_intensity', 'Sub_metering_1', 'Sub_metering_2', 'Sub_metering_3']## 获取特征X和目标YX = data[names[0:2]]X = X.apply(lambda x:pd.Series(data_fromat(x)),axis=1)Y= data[names[4]]
重点:设置1~5阶多项式扩展,看哪个模型的最终拟合度最好
# 对数据集进行测试集合训练集划分X_train,X_test,Y_train,Y_test = train_test_split(X, Y, test_size=0.2, random_state=0)## 数据标准化,将这一步放到了管道中## 当然也可以删除管道的第0行,放开下面的注释,## 得到的结果是一样的#ss = StandardScaler()#X_train = ss.fit_transform(X_train) # 训练并转换#X_test = ss.transform(X_test) ## 直接使用在模型构建数据上进行一个数据标准化操作 # 模型训练t=np.arange(len(X_test))##设置1~5阶多项式扩展,看哪个模型的最终拟合度最好N = 5d_pool = np.arange(1,N,1) # 阶m = d_pool.sizeclrs = [] # 颜色for c in np.linspace(16711680, 255, m): clrs.append('#%06x' % int(c))line_width = 3plt.figure(figsize=(12,6), facecolor='w')#创建一个绘图窗口,设置大小,设置颜色for i,d in enumerate(d_pool): plt.subplot(N-1,1,i+1) plt.plot(t, Y_test, 'r-', label=u'真实值', ms=10, zorder=N) ## 对Poly模型中的degree参数进行赋值 ## 模型名__参数 model.set_params(Poly__degree=d) ## 设置多项式的阶乘 ## fit完后流转到下一个节点 ## 虽然标准化数据的时候只针对X_train,但因为后面进入Estimator环节需要Y_train的数据,所以一并传入 model.fit(X_train, Y_train) ## model.get_params()调用管道中所有模型的参数 ## ['Linear'] 提取Linear模型的参数 lin = model.get_params()['Linear'] output = u'%d阶,系数为:' % d ## 判断Linear模型中是否有alpha这个参数 if hasattr(lin, 'alpha_'): idx = output.find(u'系数') output = output[:idx] + (u'alpha=%.6f, ' % lin.alpha_) + output[idx:] ## 判断Linear模型中是否有l1_ratio这个参数 if hasattr(lin, 'l1_ratio_'): idx = output.find(u'系数') output = output[:idx] + (u'l1_ratio=%.6f, ' % lin.l1_ratio_) + output[idx:] ## 输出Linear模型中θ1~θn的属性 print ('==',output, lin.coef_.ravel()) y_hat = model.predict(X_test) s = model.score(X_test, Y_test) z = N - 1 if (d == 2) else 0 label = u'%d阶, 准确率=%.3f' % (d,s) plt.plot(t, y_hat, color=clrs[i], lw=line_width, alpha=0.75, label=label, zorder=z) plt.legend(loc = 'upper left') plt.grid(True) plt.ylabel(u'%d阶结果' % d, fontsize=12)## 预测值和实际值画图比较plt.suptitle(u"线性回归预测时间和功率之间的多项式关系", fontsize=20)plt.grid(b=True)plt.show()
结果:
== 1阶,系数为: [ 234.635375 0. 0. 0. -0.50376467 -0.43146823 0. ]== 2阶,系数为: [ 2.35318068e+02 -6.57252031e-14 1.11022302e-16 -6.55031585e-15 -5.71168427e-01 -3.76554668e-01 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 -9.24737046e-01 -8.06322387e-01 0.00000000e+00 -1.74179681e-01 0.00000000e+00 0.00000000e+00]== 3阶,系数为: [ 2.13129561e+11 1.32727630e+12 -1.30480175e+12 -5.84592052e+12 -3.54666596e+11 2.83229781e-01 1.67040304e+12 -4.22518953e+12 4.85924486e+12 6.79144658e+12 -3.32463273e+12 -1.08242302e+13 -1.09288546e+13 -1.76420169e+12 -5.09819286e+11 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 -2.07077200e+11 -1.26342773e-01 0.00000000e+00 1.85546875e-02 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 1.37336540e+11 -2.27539062e-01 0.00000000e+00 -8.84399414e-01 0.00000000e+00 0.00000000e+00 -5.73242188e-01 0.00000000e+00 0.00000000e+00 0.00000000e+00]== 4阶,系数为: [ 1.90130586e+02 3.83693077e-12 -6.94910796e-12 -4.78195261e-12 7.02006845e+01 3.40996949e+00 -1.49213975e-12 1.37134748e-12 -8.68638494e-13 -1.34292577e-12 -3.03351788e-12 4.48041604e-12 -1.94155803e-12 -1.24167343e-12 3.81561449e-12 -7.18181070e-12 -2.13162821e-13 2.57571742e-14 -6.57252031e-14 -1.87627691e-13 -2.75335310e-13 1.62536651e-13 5.57818267e+01 6.41779285e+00 -1.41220369e-13 7.68221007e+00 -1.68753900e-13 9.88098492e-14 4.91273688e-14 -8.52860415e-14 -5.38647519e-13 3.97555610e-13 -3.78653235e-29 5.04870979e-29 -2.52435490e-29 -5.04870979e-29 5.04870979e-29 0.00000000e+00 -1.00974196e-28 0.00000000e+00 -7.57306469e-29 1.39250583e-29 2.53951075e-29 -2.83803115e-29 6.57415717e-29 6.42382859e-30 -1.01636592e-28 9.72188732e-29 -2.80259693e-44 1.12103877e-44 -2.38220739e-44 -9.80908925e-45 -3.36311631e-44 -2.80259693e-45 -2.24207754e-44 1.68155816e-44 -2.24207754e-44 -1.01291903e-45 -3.04107162e-45 8.26065381e-46 1.71351712e-44 3.62881753e-45 2.07883572e-44 2.59181725e-44 4.97841222e-60 0.00000000e+00 0.00000000e+00 0.00000000e+00 -2.48920611e-60 1.24460306e-60 7.46761833e-60 3.11150764e-61 0.00000000e+00 -1.10542958e-75 -2.96604674e+01 -1.81198079e+01 4.79403659e-94 -1.11690774e+01 0.00000000e+00 0.00000000e+00 -5.14703828e-01 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 -9.61067605e+00 -1.60393413e+01 0.00000000e+00 -1.25463427e+01 0.00000000e+00 0.00000000e+00 1.29012532e-01 0.00000000e+00 0.00000000e+00 0.00000000e+00 2.89961813e-01 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]由于数据比较少才100条,所以模型的准确度提高得不是很多。但总体来上还是在进步的:
转载地址:http://kfrgl.baihongyu.com/