Linear regression fitting with the sklearn library
python
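Background
Temperature drops as altitude increases. Meteorologists have found that, within a certain range of altitudes, height and temperature are linearly related. We have a set of measured data and need to fit it to recover this linear relationship.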
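The relationship can be written as a straight line. In the sketch below, the symbols are chosen here for illustration: $h$ is altitude, $T$ is temperature, $a$ is the slope (what the code below reads from `reg.coef_`) and $b$ is the intercept (`reg.intercept_`). `LinearRegression` estimates them by ordinary least squares, which for a single feature has this closed form:

```latex
\begin{aligned}
T &= a\,h + b \\
(\hat a, \hat b) &= \arg\min_{a,b} \sum_{i=1}^{n} \bigl(T_i - (a\,h_i + b)\bigr)^2 \\
\hat a &= \frac{\sum_{i=1}^{n} (h_i - \bar h)(T_i - \bar T)}{\sum_{i=1}^{n} (h_i - \bar h)^2},
\qquad
\hat b = \bar T - \hat a\,\bar h
\end{aligned}
```

For the nine measurements used here ($n = 9$, heights 0 to 4000 in steps of 500), $\bar h = 2000$.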
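Approach
Fit the data with the LinearRegression class from sklearn's linear_model module (sklearn.linear_model.LinearRegression), which performs an ordinary least-squares fit.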
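As a cross-check, the slope and intercept can also be computed directly from the closed-form least-squares formulas with NumPy and compared with what LinearRegression returns. This is a minimal sketch: `least_squares_line` is a hypothetical helper written for illustration, not part of sklearn. It also passes `y` as a 1D array, which LinearRegression accepts (only `X` has to be 2D), so `coef_` has shape `(1,)` and `intercept_` is a scalar.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Measured data from the article
height = np.array([0.0, 500.0, 1000.0, 1500.0, 2000.0, 2500.0, 3000.0, 3500.0, 4000.0])
temperature = np.array([12.834044009405147, 10.190648986884316, 5.50022874963469,
                        2.8546651452636795, -0.7064882183657739, -4.065322810462405,
                        -7.1274795772446575, -10.058878545913904, -13.206465051538661])


def least_squares_line(x, y):
    """Hypothetical helper: closed-form ordinary least squares for y = a*x + b."""
    x_mean, y_mean = x.mean(), y.mean()
    a = np.sum((x - x_mean) * (y - y_mean)) / np.sum((x - x_mean) ** 2)
    b = y_mean - a * x_mean
    return a, b


a_manual, b_manual = least_squares_line(height, temperature)

# sklearn: X must be 2D with shape (n_samples, n_features); y may stay 1D
reg = LinearRegression().fit(height.reshape(-1, 1), temperature)
a_sk, b_sk = reg.coef_[0], reg.intercept_

print('closed form: Y = %.6fX + %.6f' % (a_manual, b_manual))
print('sklearn:     Y = %.6fX + %.6f' % (a_sk, b_sk))
```

Both lines should agree up to floating-point rounding, which confirms that LinearRegression is doing a plain least-squares fit here.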
Code
```python
# Import the required modules
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression

# Measured data
height = [0.0, 500.0, 1000.0, 1500.0, 2000.0, 2500.0, 3000.0, 3500.0, 4000.0]
temperature = [12.834044009405147, 10.190648986884316, 5.50022874963469,
               2.8546651452636795, -0.7064882183657739, -4.065322810462405,
               -7.1274795772446575, -10.058878545913904, -13.206465051538661]

# Data preparation
# sklearn expects the feature matrix X to be 2D, shape (n_samples, n_features),
# so the 1D list is reshaped into a column vector. y is reshaped to 2D as well
# (sklearn would also accept 1D y), which is why coef_ and intercept_ are indexed below.
height = np.array(height).reshape(-1, 1)
temp = np.array(temperature).reshape(-1, 1)

# Fit
reg = LinearRegression()
reg.fit(height, temp)
a = reg.coef_[0][0]      # slope (coef_ has shape (1, 1) for 2D y)
b = reg.intercept_[0]    # intercept (intercept_ has shape (1,))
print('Fitted equation: Y = %.6fX + %.6f' % (a, b))
# out: Fitted equation: Y = -0.006570X + 12.718507

# Visualization
prediction = reg.predict(height)  # temperature predicted by the fitted line at each height
plt.figure('Altitude vs. temperature: linear fit', figsize=(12, 8))
# Chinese labels need a CJK font, e.g. (SimHei must be installed):
# plt.rcParams['font.family'] = ['sans-serif']
# plt.rcParams['font.sans-serif'] = ['SimHei']
plt.xlabel('Temperature')
plt.ylabel('Height')
plt.scatter(temp, height, c='black')   # measured points
plt.plot(prediction, height, c='r')    # fitted line
plt.show()
```
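Once fitted, the same model can estimate temperatures at altitudes that were not measured, and `reg.score` reports the coefficient of determination R² on the given data. This is a minimal sketch assuming heights are in metres and temperatures in °C (the article does not state units). The query altitudes are hypothetical and deliberately kept inside the measured 0 to 4000 range, because the linear relationship is only claimed to hold within a certain range of altitudes.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Same measured data and 2D reshaping as in the article's code
height = np.array([0.0, 500.0, 1000.0, 1500.0, 2000.0, 2500.0,
                   3000.0, 3500.0, 4000.0]).reshape(-1, 1)
temp = np.array([12.834044009405147, 10.190648986884316, 5.50022874963469,
                 2.8546651452636795, -0.7064882183657739, -4.065322810462405,
                 -7.1274795772446575, -10.058878545913904,
                 -13.206465051538661]).reshape(-1, 1)
reg = LinearRegression().fit(height, temp)

# Hypothetical query altitudes, each inside the measured range
new_heights = np.array([[800.0], [1200.0], [3200.0]])
predicted = reg.predict(new_heights)  # shape (3, 1) because y was fitted as 2D
r2 = reg.score(height, temp)          # R^2 of the fit on the measured data

for h, t in zip(new_heights.ravel(), predicted.ravel()):
    print('height %.0f -> predicted temperature %.2f' % (h, t))
print('R^2 = %.4f' % r2)
```

Because `y` was passed as a column vector, `predict` returns a column vector too; calling `.ravel()` flattens it for printing.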
