数据挖掘课程项目

项目名称:Covid-19 全球性预测
成员及分工

  • 杜建成(3220190786) 模型构建,模型相关部分文档撰写
  • 聂宇翔(1120161722) 主要文档撰写,数据、模型间对比分析
  • 陈子康(3220190783) 模型构建,模型相关部分文档撰写
  • 赵菊文(3220191000) 模型构建,模型相关部分文档撰写
  • 窦京伟(3120190314) 数据收集与预处理、可视化以及相关部分文档撰写

数据集:COVID19 Global Forecasting (Week 5)
Github地址https://github.com/ChengchengDu/data_mining_project
:数据集处理代码、数据挖掘代码以及挖掘过程报告已经按照一定的顺序整理在这个notebook中
运行建议:如果要对本课程项目代码进行运行,只需要把该notebook与数据集放在同一个目录下运行即可

内容

0. 问题描述

0.1. 问题背景及分析

0.2. 问题描述

0.2.1. 数据准备

0.2.2. 准备采用的方法或模型

0.2.3. 预期的挖掘结果

0.3 项目评估

1. 数据摘要

1.1. 导库

1.2. 加载数据集

1.3. 检测缺失值

2. 数据可视化

3. 数据清洗

4. 数据格式转换

4.1. 导库

4.2. 数据简要摘要

4.3. 数据信息类型转换

5. 模型构建、训练和预测

5.0 Pinball Loss 函数构建

5.1. RandomForestRegressor模型

5.1.1. 模型构建及训练

5.1.2. 模型预测

5.2. XGBRegressor模型

5.2.1. 模型构建、训练以及预测

5.3. LinearRegression模型

5.3.1. 模型构建、训练以及预测

5.4. 模型间效果的比较分析

5.4.1. MSE评估方法下的模型效果比较

5.4.2. Pinball Loss评估方法下的模型效果比较

0. 问题描述

0.1. 问题背景及分析

白宫科学技术政策办公室(OSTP)召集了一个联盟研究小组和公司(包括Kaggle)来准备COVID-19开放研究数据集(CORD-19),以尝试解决有关COVID-19的关键开放科学问题。这些问题来自美国国家科学,工程和医学研究院(NASEM)和世界卫生组织(WHO)。

Kaggle正在发起COVID-19预测挑战,以帮助回答NASEM / WHO问题的一部分。尽管面临的挑战是为每个地区确定5月12日至6月7日之间确诊病例和死亡人数的分位数估计间隔,但主要目标不仅是产生准确的预测。还可以识别出可能影响COVID-19传输速率的因素。

上述挑战是一个回归分析问题。我们组所采用的数据集是Kaggle中COVID19 Global Forecasting (Week 5)任务的数据集。我们通过构建不同的模型并对比分析不同模型在这些数据集上的优劣的方式,对COVID-19未来在世界范围内的传播趋势进行分析。

0.2 问题描述

0.2.1. 数据准备

采集数据集中的数据,将区域(如全美国)分成适当的n*n(如20*20)的数量级的小数据集,然后在每个小区域集上又根据时间统计各个时间段的确诊病例和死亡人数,最后可以得到根据时间段和区域统计的数据统计,完成初步的数据准备。

0.2.2 准备采用的方法或模型

  1. 使用XGBOOST对训练集进行训练,在测试集上对该地区新冠肺炎确认人数和死亡人数进行预测
  2. 使用RandomForest对训练集进行训练,在测试集上对该地区新冠肺炎确认人数和死亡人数进行预测
  3. 使用logistic regression对训练集进行训练,在测试集上对该地区新冠肺炎确认人数和死亡人数进行预测

0.2.3 预期的挖掘结果

通过所构建的模型实现对未来COVID-19在确诊人数和死亡人数上较为准确的预测。

0.3 项目评估

预测结果主要利用加权Pinball Loss指标,对各个模型在测试数据集上的性能进行评估。加权Pinball Loss的定义如下所示: $$ \text{score}=\frac{1}{N_f}\sum_fw_f\frac{1}{N_\tau}\sum_\tau L_\tau(y_i,\hat{y}_i)$$ 其中, $$L_\tau(y, \hat{y})=(y-\hat{y})\tau\ \ \text{if}\ y\geq \hat{y}\\ =(\hat{y}-y)(1-\tau)\ \ \text{if}\ \hat{y}\gt y$$ 并且有, $y$是真实值

$\hat{y}$是预测值

$\tau$是需要被预测的分位数

$N_f$是总的预测数

$N_\tau$是总的需要被预测的分位数

$w$是权重因子

1. 数据摘要

训练集各列名含义如下:

Field Name Definition
Id id
County
Province_State 省份/州
Country_Region 国家
Population 人口数
Date 时间
Target 确诊/死亡
TargetValue 确诊数/死亡数

1.1. 导库

In [2]:
import pandas as pd
import numpy as np
import itertools
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline
import plotly.express as px
import statsmodels.api as sm

import warnings
warnings.simplefilter('ignore')

plt.rcParams['font.family'] = ['sans-serif']
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.style.use('ggplot')

1.2.加载数据集

In [3]:
train=pd.read_csv("./train.csv")
test=pd.read_csv("./test.csv")
sample = pd.read_csv('./submission.csv')
In [4]:
train.head(10)
Out[4]:
Id County Province_State Country_Region Population Weight Date Target TargetValue
0 1 NaN NaN Afghanistan 27657145 0.058359 2020-01-23 ConfirmedCases 0.0
1 2 NaN NaN Afghanistan 27657145 0.583587 2020-01-23 Fatalities 0.0
2 3 NaN NaN Afghanistan 27657145 0.058359 2020-01-24 ConfirmedCases 0.0
3 4 NaN NaN Afghanistan 27657145 0.583587 2020-01-24 Fatalities 0.0
4 5 NaN NaN Afghanistan 27657145 0.058359 2020-01-25 ConfirmedCases 0.0
5 6 NaN NaN Afghanistan 27657145 0.583587 2020-01-25 Fatalities 0.0
6 7 NaN NaN Afghanistan 27657145 0.058359 2020-01-26 ConfirmedCases 0.0
7 8 NaN NaN Afghanistan 27657145 0.583587 2020-01-26 Fatalities 0.0
8 9 NaN NaN Afghanistan 27657145 0.058359 2020-01-27 ConfirmedCases 0.0
9 10 NaN NaN Afghanistan 27657145 0.583587 2020-01-27 Fatalities 0.0
In [5]:
train.shape
Out[5]:
(865750, 9)
In [6]:
test.head(10)
Out[6]:
ForecastId County Province_State Country_Region Population Weight Date Target
0 1 NaN NaN Afghanistan 27657145 0.058359 2020-04-27 ConfirmedCases
1 2 NaN NaN Afghanistan 27657145 0.583587 2020-04-27 Fatalities
2 3 NaN NaN Afghanistan 27657145 0.058359 2020-04-28 ConfirmedCases
3 4 NaN NaN Afghanistan 27657145 0.583587 2020-04-28 Fatalities
4 5 NaN NaN Afghanistan 27657145 0.058359 2020-04-29 ConfirmedCases
5 6 NaN NaN Afghanistan 27657145 0.583587 2020-04-29 Fatalities
6 7 NaN NaN Afghanistan 27657145 0.058359 2020-04-30 ConfirmedCases
7 8 NaN NaN Afghanistan 27657145 0.583587 2020-04-30 Fatalities
8 9 NaN NaN Afghanistan 27657145 0.058359 2020-05-01 ConfirmedCases
9 10 NaN NaN Afghanistan 27657145 0.583587 2020-05-01 Fatalities
In [7]:
test.shape
Out[7]:
(311670, 8)
In [8]:
sample.head(10)
Out[8]:
ForecastId_Quantile TargetValue
0 1_0.05 1
1 1_0.5 1
2 1_0.95 1
3 2_0.05 1
4 2_0.5 1
5 2_0.95 1
6 3_0.05 1
7 3_0.5 1
8 3_0.95 1
9 4_0.05 1
In [9]:
sample.shape
Out[9]:
(935010, 2)

训练集数据量为9列,969640行

测试集数据量为8列,311670行

sample为2列,935010行

In [10]:
#查看训练集各列数据类型
train.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 865750 entries, 0 to 865749
Data columns (total 9 columns):
Id                865750 non-null int64
County            785750 non-null object
Province_State    819000 non-null object
Country_Region    865750 non-null object
Population        865750 non-null int64
Weight            865750 non-null float64
Date              865750 non-null object
Target            865750 non-null object
TargetValue       865750 non-null float64
dtypes: float64(2), int64(2), object(5)
memory usage: 59.4+ MB

1.3 检测缺失值

In [11]:
train.isnull().sum()
Out[11]:
Id                    0
County            80000
Province_State    46750
Country_Region        0
Population            0
Weight                0
Date                  0
Target                0
TargetValue           0
dtype: int64
In [12]:
test.isnull().sum()
Out[12]:
ForecastId            0
County            28800
Province_State    16830
Country_Region        0
Population            0
Weight                0
Date                  0
Target                0
dtype: int64
In [13]:
sample.isnull().sum()
Out[13]:
ForecastId_Quantile    0
TargetValue            0
dtype: int64

训练集中Countyh和Province_State字段存在缺失

测试集中Countyh和Province_State字段存在缺失

sample 中无缺失

2. 数据可视化

In [14]:
fig = px.pie(train, values='TargetValue', names='Target')
fig.update_traces(textposition='inside')
fig.update_layout(uniformtext_minsize=12, uniformtext_mode='hide')
fig.show()

训练集确诊数占全部数量的94.2%,死亡数占5.79%

In [15]:
fig = px.pie(train, values='TargetValue', names='Country_Region')
fig.update_traces(textposition='inside')
fig.update_layout(uniformtext_minsize=12, uniformtext_mode='hide')
fig.show()

训练集中美国新冠肺炎确诊和死亡总数量最多是美国(54.9%),其次是巴西(4.28%),俄罗斯(3.77%)。

In [16]:
#查看训练集的最近日期
last_date = train.Date.max()
df_countries = train[train['Date']==last_date]
df_countries = df_countries.groupby('Country_Region', as_index=False)['TargetValue'].sum()
df_countries = df_countries.nlargest(10,'TargetValue')
df_trend = train.groupby(['Date','Country_Region'], as_index=False)['TargetValue'].sum()
df_trend = df_trend.merge(df_countries, on='Country_Region')
df_trend.rename(columns={'Country_Region':'Country', 'TargetValue_x':'Cases'}, inplace=True)
In [17]:
px.line(df_trend, x='Date', y='Cases', color='Country',title='新冠肺炎最严重10个国家确诊/死亡总数量折线图 ')

由折线图可知,从2020年2月2日开始到2020年6月7日,美国自3月下旬开始迅速爆发

3. 数据清洗

  • 缺失值处理,County和Province_State存在缺失,使用Province_State填充County,使用Country_Region填充Province_State
  • Country_Region,Target为分类型变量(ConfirmedCases,Fatalities),对其进行数据编码
  • 对日期进行处理
  • 选择模型预测的特征(Country_Region,Population,Weight,Target)进行预测
In [18]:
#缺失值处理
train.County.fillna(train.Province_State, inplace=True)
test.County.fillna(test.Province_State, inplace=True)

train.Province_State.fillna(train.Country_Region, inplace=True)
test.Province_State.fillna(test.Country_Region, inplace=True)

train.isnull().sum()
Out[18]:
Id                    0
County            46750
Province_State        0
Country_Region        0
Population            0
Weight                0
Date                  0
Target                0
TargetValue           0
dtype: int64
In [19]:
#数据编码
from sklearn.preprocessing import LabelEncoder
labelencoder = LabelEncoder()
train['Country_Region'] = labelencoder.fit_transform(train['Country_Region'])
train['Target'] = labelencoder.fit_transform(train['Target'])

test['Country_Region'] = labelencoder.fit_transform(test['Country_Region'])
test['Target'] = labelencoder.fit_transform(test['Target'])
In [20]:
#对日期进行处理
train['Date'] = pd.to_datetime(train['Date'], infer_datetime_format=True)
test['Date'] = pd.to_datetime(test['Date'], infer_datetime_format=True)

train.loc[:, 'Date'] = train.Date.dt.strftime("%Y%m%d")
train.loc[:, 'Date'] = train['Date'].astype(int)

test.loc[:, 'Date'] = test.Date.dt.strftime("%Y%m%d")
test.loc[:, 'Date'] = test['Date'].astype(int)
In [21]:
X_train=train.drop(['Id', 'County', 'Province_State','TargetValue'],axis=1)
y_train=train['TargetValue']
In [22]:
df_test=test.drop(columns=['County','Province_State','ForecastId'])

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X_train, y_train, test_size=0.2, random_state=0)
In [23]:
X_train.head()
Out[23]:
Country_Region Population Weight Date Target
660027 173 415759 0.772925 20200205 1
649230 173 164292 0.083268 20200517 0
98902 173 13024 0.105545 20200408 0
514454 173 124371 0.085244 20200504 0
720779 173 17297 1.024764 20200206 1
In [24]:
y_train.head()
Out[24]:
660027    0.0
649230    0.0
98902     1.0
514454    4.0
720779    0.0
Name: TargetValue, dtype: float64
In [25]:
#保存清洗后数据集
df = train.drop(['Id', 'County', 'Province_State'],axis=1)
In [26]:
df.to_csv('./result.csv') 
In [27]:
#####进行模型预测

4. 数据格式转换

4.1. 导库

In [28]:
import pandas as pd
from sklearn.preprocessing import OrdinalEncoder, StandardScaler

4.2. 数据简要摘要

In [29]:
train = pd.read_csv('.\\covid19-global-forecasting-week-5\\train.csv')
train_x = train.dropna()
train.describe()
Out[29]:
Id Population Weight TargetValue
count 865750.000000 8.657500e+05 865750.000000 865750.000000
mean 484805.500000 2.720127e+06 0.530870 11.171695
std 279911.116801 3.477771e+07 0.451909 277.581136
min 1.000000 8.600000e+01 0.047491 -10034.000000
25% 242388.250000 1.213300e+04 0.096838 0.000000
50% 484805.500000 3.053100e+04 0.349413 0.000000
75% 727222.750000 1.056120e+05 0.968379 0.000000
max 969610.000000 1.395773e+09 2.239186 36163.000000
In [30]:
test = pd.read_csv('.\\covid19-global-forecasting-week-5\\test.csv')
test_x = test.dropna()
test.describe()
Out[30]:
ForecastId Population Weight
count 311670.000000 3.116700e+05 311670.000000
mean 155835.500000 2.720127e+06 0.530870
std 89971.523537 3.477775e+07 0.451910
min 1.000000 8.600000e+01 0.047491
25% 77918.250000 1.213300e+04 0.096838
50% 155835.500000 3.053100e+04 0.349413
75% 233752.750000 1.056120e+05 0.968379
max 311670.000000 1.395773e+09 2.239186
In [31]:
train_x.head(10)
Out[31]:
Id County Province_State Country_Region Population Weight Date Target TargetValue
60500 67761 Autauga Alabama US 55869 0.091485 2020-01-23 ConfirmedCases 0.0
60501 67762 Autauga Alabama US 55869 0.914848 2020-01-23 Fatalities 0.0
60502 67763 Autauga Alabama US 55869 0.091485 2020-01-24 ConfirmedCases 0.0
60503 67764 Autauga Alabama US 55869 0.914848 2020-01-24 Fatalities 0.0
60504 67765 Autauga Alabama US 55869 0.091485 2020-01-25 ConfirmedCases 0.0
60505 67766 Autauga Alabama US 55869 0.914848 2020-01-25 Fatalities 0.0
60506 67767 Autauga Alabama US 55869 0.091485 2020-01-26 ConfirmedCases 0.0
60507 67768 Autauga Alabama US 55869 0.914848 2020-01-26 Fatalities 0.0
60508 67769 Autauga Alabama US 55869 0.091485 2020-01-27 ConfirmedCases 0.0
60509 67770 Autauga Alabama US 55869 0.914848 2020-01-27 Fatalities 0.0

4.3. 数据信息类型转换

In [32]:
# 特征工程
# 处理日期信息
train['Date'] = pd.DatetimeIndex(train['Date'])
test['Date'] = pd.DatetimeIndex(test['Date'])
# 增加月份特征
train['Month'] = train['Date'].dt.month
test['Month'] = test['Date'].dt.month
# 增加日特征
train['dayofyear'] = train['Date'].dt.dayofyear
train['quarter'] = train['Date'].dt.quarter
train['weekofyear'] = train['Date'].dt.weekofyear

test['dayofyear'] = test['Date'].dt.dayofyear
test['quarter'] = test['Date'].dt.quarter
test['weekofyear'] = test['Date'].dt.weekofyear

train_x = train.dropna()
test_x = test.dropna()

train_x.head()
test_x.head()
Out[32]:
ForecastId County Province_State Country_Region Population Weight Date Target Month dayofyear quarter weekofyear
21780 21781 Autauga Alabama US 55869 0.091485 2020-04-27 ConfirmedCases 4 118 2 18
21781 21782 Autauga Alabama US 55869 0.914848 2020-04-27 Fatalities 4 118 2 18
21782 21783 Autauga Alabama US 55869 0.091485 2020-04-28 ConfirmedCases 4 119 2 18
21783 21784 Autauga Alabama US 55869 0.914848 2020-04-28 Fatalities 4 119 2 18
21784 21785 Autauga Alabama US 55869 0.091485 2020-04-29 ConfirmedCases 4 120 2 18
In [33]:
columns = ['Country_Region', 'Target']
enc = OrdinalEncoder()
train[columns] = enc.fit_transform(train[columns])
test[columns] = enc.transform(test[columns])
In [34]:
Y = train['TargetValue']
train_x = train.drop(['Id', 'County', 'Province_State', 'TargetValue', 'Date'], axis=1)
test_x = test.drop(['County','Province_State','ForecastId', 'Date'], axis=1)
print(train_x)
print(test_x)
        Country_Region  Population    Weight  Target  Month  dayofyear  \
0                  0.0    27657145  0.058359     0.0      1         23   
1                  0.0    27657145  0.583587     1.0      1         23   
2                  0.0    27657145  0.058359     0.0      1         24   
3                  0.0    27657145  0.583587     1.0      1         24   
4                  0.0    27657145  0.058359     0.0      1         25   
5                  0.0    27657145  0.583587     1.0      1         25   
6                  0.0    27657145  0.058359     0.0      1         26   
7                  0.0    27657145  0.583587     1.0      1         26   
8                  0.0    27657145  0.058359     0.0      1         27   
9                  0.0    27657145  0.583587     1.0      1         27   
10                 0.0    27657145  0.058359     0.0      1         28   
11                 0.0    27657145  0.583587     1.0      1         28   
12                 0.0    27657145  0.058359     0.0      1         29   
13                 0.0    27657145  0.583587     1.0      1         29   
14                 0.0    27657145  0.058359     0.0      1         30   
15                 0.0    27657145  0.583587     1.0      1         30   
16                 0.0    27657145  0.058359     0.0      1         31   
17                 0.0    27657145  0.583587     1.0      1         31   
18                 0.0    27657145  0.058359     0.0      2         32   
19                 0.0    27657145  0.583587     1.0      2         32   
20                 0.0    27657145  0.058359     0.0      2         33   
21                 0.0    27657145  0.583587     1.0      2         33   
22                 0.0    27657145  0.058359     0.0      2         34   
23                 0.0    27657145  0.583587     1.0      2         34   
24                 0.0    27657145  0.058359     0.0      2         35   
25                 0.0    27657145  0.583587     1.0      2         35   
26                 0.0    27657145  0.058359     0.0      2         36   
27                 0.0    27657145  0.583587     1.0      2         36   
28                 0.0    27657145  0.058359     0.0      2         37   
29                 0.0    27657145  0.583587     1.0      2         37   
...                ...         ...       ...     ...    ...        ...   
865720           186.0    14240168  0.060711     0.0      5        133   
865721           186.0    14240168  0.607106     1.0      5        133   
865722           186.0    14240168  0.060711     0.0      5        134   
865723           186.0    14240168  0.607106     1.0      5        134   
865724           186.0    14240168  0.060711     0.0      5        135   
865725           186.0    14240168  0.607106     1.0      5        135   
865726           186.0    14240168  0.060711     0.0      5        136   
865727           186.0    14240168  0.607106     1.0      5        136   
865728           186.0    14240168  0.060711     0.0      5        137   
865729           186.0    14240168  0.607106     1.0      5        137   
865730           186.0    14240168  0.060711     0.0      5        138   
865731           186.0    14240168  0.607106     1.0      5        138   
865732           186.0    14240168  0.060711     0.0      5        139   
865733           186.0    14240168  0.607106     1.0      5        139   
865734           186.0    14240168  0.060711     0.0      5        140   
865735           186.0    14240168  0.607106     1.0      5        140   
865736           186.0    14240168  0.060711     0.0      5        141   
865737           186.0    14240168  0.607106     1.0      5        141   
865738           186.0    14240168  0.060711     0.0      5        142   
865739           186.0    14240168  0.607106     1.0      5        142   
865740           186.0    14240168  0.060711     0.0      5        143   
865741           186.0    14240168  0.607106     1.0      5        143   
865742           186.0    14240168  0.060711     0.0      5        144   
865743           186.0    14240168  0.607106     1.0      5        144   
865744           186.0    14240168  0.060711     0.0      5        145   
865745           186.0    14240168  0.607106     1.0      5        145   
865746           186.0    14240168  0.060711     0.0      5        146   
865747           186.0    14240168  0.607106     1.0      5        146   
865748           186.0    14240168  0.060711     0.0      5        147   
865749           186.0    14240168  0.607106     1.0      5        147   

        quarter  weekofyear  
0             1           4  
1             1           4  
2             1           4  
3             1           4  
4             1           4  
5             1           4  
6             1           4  
7             1           4  
8             1           5  
9             1           5  
10            1           5  
11            1           5  
12            1           5  
13            1           5  
14            1           5  
15            1           5  
16            1           5  
17            1           5  
18            1           5  
19            1           5  
20            1           5  
21            1           5  
22            1           6  
23            1           6  
24            1           6  
25            1           6  
26            1           6  
27            1           6  
28            1           6  
29            1           6  
...         ...         ...  
865720        2          20  
865721        2          20  
865722        2          20  
865723        2          20  
865724        2          20  
865725        2          20  
865726        2          20  
865727        2          20  
865728        2          20  
865729        2          20  
865730        2          20  
865731        2          20  
865732        2          21  
865733        2          21  
865734        2          21  
865735        2          21  
865736        2          21  
865737        2          21  
865738        2          21  
865739        2          21  
865740        2          21  
865741        2          21  
865742        2          21  
865743        2          21  
865744        2          21  
865745        2          21  
865746        2          22  
865747        2          22  
865748        2          22  
865749        2          22  

[865750 rows x 8 columns]
        Country_Region  Population    Weight  Target  Month  dayofyear  \
0                  0.0    27657145  0.058359     0.0      4        118   
1                  0.0    27657145  0.583587     1.0      4        118   
2                  0.0    27657145  0.058359     0.0      4        119   
3                  0.0    27657145  0.583587     1.0      4        119   
4                  0.0    27657145  0.058359     0.0      4        120   
5                  0.0    27657145  0.583587     1.0      4        120   
6                  0.0    27657145  0.058359     0.0      4        121   
7                  0.0    27657145  0.583587     1.0      4        121   
8                  0.0    27657145  0.058359     0.0      5        122   
9                  0.0    27657145  0.583587     1.0      5        122   
10                 0.0    27657145  0.058359     0.0      5        123   
11                 0.0    27657145  0.583587     1.0      5        123   
12                 0.0    27657145  0.058359     0.0      5        124   
13                 0.0    27657145  0.583587     1.0      5        124   
14                 0.0    27657145  0.058359     0.0      5        125   
15                 0.0    27657145  0.583587     1.0      5        125   
16                 0.0    27657145  0.058359     0.0      5        126   
17                 0.0    27657145  0.583587     1.0      5        126   
18                 0.0    27657145  0.058359     0.0      5        127   
19                 0.0    27657145  0.583587     1.0      5        127   
20                 0.0    27657145  0.058359     0.0      5        128   
21                 0.0    27657145  0.583587     1.0      5        128   
22                 0.0    27657145  0.058359     0.0      5        129   
23                 0.0    27657145  0.583587     1.0      5        129   
24                 0.0    27657145  0.058359     0.0      5        130   
25                 0.0    27657145  0.583587     1.0      5        130   
26                 0.0    27657145  0.058359     0.0      5        131   
27                 0.0    27657145  0.583587     1.0      5        131   
28                 0.0    27657145  0.058359     0.0      5        132   
29                 0.0    27657145  0.583587     1.0      5        132   
...                ...         ...       ...     ...    ...        ...   
311640           186.0    14240168  0.060711     0.0      5        148   
311641           186.0    14240168  0.607106     1.0      5        148   
311642           186.0    14240168  0.060711     0.0      5        149   
311643           186.0    14240168  0.607106     1.0      5        149   
311644           186.0    14240168  0.060711     0.0      5        150   
311645           186.0    14240168  0.607106     1.0      5        150   
311646           186.0    14240168  0.060711     0.0      5        151   
311647           186.0    14240168  0.607106     1.0      5        151   
311648           186.0    14240168  0.060711     0.0      5        152   
311649           186.0    14240168  0.607106     1.0      5        152   
311650           186.0    14240168  0.060711     0.0      6        153   
311651           186.0    14240168  0.607106     1.0      6        153   
311652           186.0    14240168  0.060711     0.0      6        154   
311653           186.0    14240168  0.607106     1.0      6        154   
311654           186.0    14240168  0.060711     0.0      6        155   
311655           186.0    14240168  0.607106     1.0      6        155   
311656           186.0    14240168  0.060711     0.0      6        156   
311657           186.0    14240168  0.607106     1.0      6        156   
311658           186.0    14240168  0.060711     0.0      6        157   
311659           186.0    14240168  0.607106     1.0      6        157   
311660           186.0    14240168  0.060711     0.0      6        158   
311661           186.0    14240168  0.607106     1.0      6        158   
311662           186.0    14240168  0.060711     0.0      6        159   
311663           186.0    14240168  0.607106     1.0      6        159   
311664           186.0    14240168  0.060711     0.0      6        160   
311665           186.0    14240168  0.607106     1.0      6        160   
311666           186.0    14240168  0.060711     0.0      6        161   
311667           186.0    14240168  0.607106     1.0      6        161   
311668           186.0    14240168  0.060711     0.0      6        162   
311669           186.0    14240168  0.607106     1.0      6        162   

        quarter  weekofyear  
0             2          18  
1             2          18  
2             2          18  
3             2          18  
4             2          18  
5             2          18  
6             2          18  
7             2          18  
8             2          18  
9             2          18  
10            2          18  
11            2          18  
12            2          18  
13            2          18  
14            2          19  
15            2          19  
16            2          19  
17            2          19  
18            2          19  
19            2          19  
20            2          19  
21            2          19  
22            2          19  
23            2          19  
24            2          19  
25            2          19  
26            2          19  
27            2          19  
28            2          20  
29            2          20  
...         ...         ...  
311640        2          22  
311641        2          22  
311642        2          22  
311643        2          22  
311644        2          22  
311645        2          22  
311646        2          22  
311647        2          22  
311648        2          22  
311649        2          22  
311650        2          23  
311651        2          23  
311652        2          23  
311653        2          23  
311654        2          23  
311655        2          23  
311656        2          23  
311657        2          23  
311658        2          23  
311659        2          23  
311660        2          23  
311661        2          23  
311662        2          23  
311663        2          23  
311664        2          24  
311665        2          24  
311666        2          24  
311667        2          24  
311668        2          24  
311669        2          24  

[311670 rows x 8 columns]
In [35]:
st = StandardScaler()
st.fit(train_x['Population'].values.reshape(-1, 1))
train_x['Population'] = st.transform(train_x['Population'].values.reshape(-1, 1))

print(train_x.head())
   Country_Region  Population    Weight  Target  Month  dayofyear  quarter  \
0             0.0    0.717041  0.058359     0.0      1         23        1   
1             0.0    0.717041  0.583587     1.0      1         23        1   
2             0.0    0.717041  0.058359     0.0      1         24        1   
3             0.0    0.717041  0.583587     1.0      1         24        1   
4             0.0    0.717041  0.058359     0.0      1         25        1   

   weekofyear  
0           4  
1           4  
2           4  
3           4  
4           4  
In [36]:
st2 = StandardScaler()
Y = st2.fit_transform(train['TargetValue'].values.reshape(-1, 1))
Y = Y.reshape(Y.shape[0])
print(train_x.head())
print(Y)
   Country_Region  Population    Weight  Target  Month  dayofyear  quarter  \
0             0.0    0.717041  0.058359     0.0      1         23        1   
1             0.0    0.717041  0.583587     1.0      1         23        1   
2             0.0    0.717041  0.058359     0.0      1         24        1   
3             0.0    0.717041  0.583587     1.0      1         24        1   
4             0.0    0.717041  0.058359     0.0      1         25        1   

   weekofyear  
0           4  
1           4  
2           4  
3           4  
4           4  
[-0.04024662 -0.04024662 -0.04024662 ... -0.04024662 -0.04024662
 -0.04024662]

5. 模型构建、训练和预测

5.0 Pinball Loss 函数构建

我们首先构建函数"pinball", 用于计算真实值与预测值之间的pinball loss

In [37]:
def pinball(y_true, y_pred):
    i = 4
    tao = (i + 1) / 10
    pin = np.mean(np.maximum(y_true - y_pred, 0) * tao +
                 np.maximum(y_pred - y_true, 0) * (1 - tao))
    return pin

5.1. RandomForestRegressor模型

5.1.1. 模型构建及训练

In [38]:
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor, VotingRegressor
x_train, x_test, y_train, y_test = train_test_split(train_x, Y, test_size=0.2, shuffle=False)
In [39]:
rf = RandomForestRegressor(n_jobs=-1)
rf_model = rf.fit(x_train, y_train)
y_pred = rf_model.predict(x_test)
print(y_pred)
[-0.04024662 -0.04024662 -0.04024662 ...  0.01771845  2.52945399
  0.02895841]

5.1.2. 模型预测

我们先采用MSE的方法,来对模型进行评估:

In [40]:
import numpy as np
from sklearn.metrics import mean_squared_error
np.sqrt(mean_squared_error(y_pred, y_test))
Out[40]:
1.7083738535367554

然后我们采用Pinball loss,对模型进行评估:

In [41]:
pinball(y_pred, y_test)
Out[41]:
0.034533338863260814
In [42]:
pred_rf = st2.inverse_transform(y_pred.reshape(-1, 1))
pred_rf = pred_rf.reshape(pred_rf.shape[0])
pred_rf
Out[42]:
array([-1.77635684e-15, -3.99680289e-13,  1.59872116e-14, ...,
        1.60900000e+01,  7.13300000e+02,  1.92100000e+01])

5.2. XGBRegressor模型

5.2.1. 模型构建、训练以及预测

In [43]:
from xgboost import XGBRegressor
In [44]:
xgb = XGBRegressor(n_jobs=-1)
xgb_model = xgb.fit(x_train, y_train)
y_pred_xgb = xgb_model.predict(x_test)
np.sqrt(mean_squared_error(y_pred_xgb, y_test))
Out[44]:
1.7659519202475484

可见,MSE方法评估的结果是1.766;
我们再尝试使用Pinball loss对模型进行评估:

In [46]:
pinball(y_pred_xgb, y_test)
Out[46]:
0.03413603009969335
In [47]:
pred_xgb = st2.inverse_transform(y_pred_xgb.reshape(-1, 1))
pred_xgb = pred_xgb.reshape(pred_xgb.shape[0])
pred_xgb
Out[47]:
array([ 5.3799832e-03, -7.0992164e-02,  5.3799832e-03, ...,
        6.8534348e+01,  8.0331354e+02,  6.8534348e+01], dtype=float32)

5.3. LinearRegression模型

5.3.1. 模型构建、训练以及预测

In [48]:
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score
In [49]:
lm = LinearRegression()
linear_model = lm.fit(x_train, y_train)
y_pred_linear = linear_model.predict(x_test)
np.sqrt(mean_squared_error(y_pred_linear, y_test))
Out[49]:
1.8325445116270573

可见,MSE方法评估的结果是1.833;
我们再尝试使用Pinball loss对模型进行评估:

In [50]:
pinball(y_pred_linear, y_test)
Out[50]:
0.0451407989237877
In [51]:
pred_ln = st2.inverse_transform(y_pred_linear.reshape(-1, 1))
pred_ln = pred_ln.reshape(pred_ln.shape[0])
pred_ln
Out[51]:
array([ 9.42427565, -8.80138242,  9.73187737, ...,  8.88628327,
       32.0709669 ,  9.19388499])

5.4. 模型间效果的比较分析

5.4.1. MSE评估方法下的模型效果比较

综合上述结果,我们对各个模型的效果比较如下:

| Model | MSE |
| -------- | -------- |
| RandomForestRegressor | 1.687 |
| XGBRegressor | 1.766 |
| LinearRegression | 1.833 |

通过对上述模型之间的比较分析,我们发现,
RandomForestRegressor的效果最好;
LinearRegression的效果最差;
XGBRegressor的效果在三个模型之间。

In [52]:
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X_train, y_train, test_size=0.2, random_state=0)

5.4.2. Pinball Loss评估方法下的模型效果比较

综合上述结果,我们对各个模型的效果比较如下:

| Model | Pinball Loss |
| -------- | -------- |
| RandomForestRegressor | 0.035 |
| XGBRegressor | 0.034 |
| LinearRegression | 0.045 |

通过对上述模型之间的比较分析,我们发现,
XGBRegressor的效果最好;
LinearRegression的效果最差;
RandomForestRegressor的效果在三个模型之间。

最终,我们得出的结论是,RandomForestRegressor和XGBRegressor模型在COVID19预测任务中都具有一定的竞争力,
而LinearRegression相较于之前两种模型而言,其效果较差。