3.4 训练模型

ML_TRAIN 例程在训练数据集上运行时会生成经过训练的机器学习 (ML) 模型。

ML_TRAIN 支持分类、回归和预测模型的训练。分类模型用于预测离散值。回归模型用于预测连续值。预测模型用于为日期和时间数据创建时间序列预测。

训练模型所需的时间可能需要几分钟到几小时,具体取决于数据集中的行数和列数、指定的 ML_TRAIN 参数以及 HeatWave Cluster 的大小。HeatWave ML 支持最大 10 GB 的表,最多 1 亿行和 900 列。

ML_TRAINMODEL_CATALOG在表 中存储机器学习模型 。请参阅 第 3.9.1 节,“模型目录”

有关 ML_TRAIN 选项说明,请参阅第 3.10.1 节,“ML_TRAIN”

使用的训练数据集 ML_TRAIN 必须驻留在 MySQL 数据库系统的表中。有关示例训练数据集,请参阅示例数据

以下示例 ML_TRAINheatwaveml_bench.census_train训练数据集上运行:

CALL sys.ML_TRAIN('heatwaveml_bench.census_train', 'revenue', 
JSON_OBJECT('task', 'classification'), @census_model);

在哪里:

  • heatwaveml_bench.census_train是包含训练数据集 ( schema_name.table_name) 的表的完全限定名称。

  • revenue是目标列的名称,其中包含参考标准值。

  • JSON_OBJECT('task', 'classification') 指定机器学习任务类型。支持的类型是classification(默认) regression、 和 forecasting

    NULLJSON_OBJECT如果您打算使用默认classification任务类型 ,则可以指定 。

    使用regression任务类型时,只允许使用数字目标列。

    使用forecasting任务类型时,进一步添加键值对来指定索引列和要预测的列。有关说明,请参见 第 3.8 节“预测”

  • @census_model是用户定义的会话变量的名称,它在连接期间存储模型句柄。用户变量写为 . 本指南中的一些示例 用作变量名称。允许用户定义变量的任何有效名称(例如, )。 @var_name@census_model@my_model

模型训练好后 ML_TRAIN ,存放在用户的模型目录中。要检索生成的模型句柄,查询指定的会话变量;例如:

mysql> SELECT @census_model;
+--------------------------------------------------+
| @census_model                                    |
+--------------------------------------------------+
| heatwaveml_bench.census_train_user1_1636729526   |
+--------------------------------------------------+
小费

在使用与执行 相同的连接时 ML_TRAIN,您可以指定会话变量(例如 @census_model)代替其他 HeatWave ML 例程中的模型句柄,但会话变量数据会在当前会话终止时丢失。如果需要查找模型句柄,可以通过查询模型目录表来实现。请参阅第 3.9.8 节,“模型句柄”

可以使用 ML_SCORE 例程评估经过训练的模型的质量和可靠性。有关详细信息,请参阅 第 3.9.6 节“评分模型”。从 MySQL 8.0.30 开始, ML_TRAIN 如果经过训练的模型得分较低,则显示以下消息:Model Has a low training score, expect low quality model explanations

高级 ML_TRAIN 选项

ML_TRAIN 例程提供了可用于影响模型选择和训练的高级选项。

  • model_list选项允许指定要训练的模型类型。如果指定了不止一种模型类型,则会从列表中选择最佳模型类型。有关受支持模型类型的列表,请参阅模型类型。该选项不能与 exclude_model_list选项一起使用。

    以下示例训练 XGBClassifieror LGBMClassifier模型。

    CALL sys.ml_train('heatwaveml_bench.census_train', 'revenue', 
    JSON_OBJECT('task','classification', 'model_list', 
    JSON_ARRAY('XGBClassifier', 'LGBMClassifier')), @census_model);
  • exclude_model_list选项指定不应训练的模型类型。指定的模型类型不在考虑之列。有关您可以指定的模型类型的列表,请参阅 模型类型。该选项不能与 model_list选项一起使用。

    以下示例排除了 LogisticRegressionGaussianNB模型。

    CALL sys.ml_train('heatwaveml_bench.census_train', 'revenue', 
    JSON_OBJECT('task','classification', 
    'exclude_model_list', JSON_ARRAY('LogisticRegression', 'GaussianNB')), 
    @census_model);
  • optimization_metric选项指定要优化的评分指标。有关受支持指标的列表,请参阅 评分指标

    以下示例针对 neg_log_loss指标进行了优化。

    CALL sys.ml_train('heatwaveml_bench.census_train', 'revenue', 
    JSON_OBJECT('task','classification', 'optimization_metric', 'neg_log_loss'), 
    @census_model);
  • exclude_column_list选项指定在训练模型时要排除的特征列。

    以下示例在 'age'为数据集训练模型时将该列排除在外census

    CALL sys.ml_train('heatwaveml_bench.census_train', 'revenue', 
    JSON_OBJECT('task','classification', 'exclude_column_list', JSON_ARRAY('age')), 
    @census_model);