该
ML_TRAIN
例程在训练数据集上运行时会生成经过训练的机器学习 (ML) 模型。
ML_TRAIN
支持分类、回归和预测模型的训练。分类模型用于预测离散值。回归模型用于预测连续值。预测模型用于为日期和时间数据创建时间序列预测。
训练模型所需的时间可能需要几分钟到几小时,具体取决于数据集中的行数和列数、指定的
ML_TRAIN
参数以及 HeatWave Cluster 的大小。HeatWave ML 支持最大 10 GB 的表,最多 1 亿行和 900 列。
ML_TRAIN
MODEL_CATALOG
在表
中存储机器学习模型
。请参阅
第 3.9.1 节,“模型目录”。
有关
ML_TRAIN
选项说明,请参阅第 3.10.1 节,“ML_TRAIN”。
使用的训练数据集
ML_TRAIN
必须驻留在 MySQL 数据库系统的表中。有关示例训练数据集,请参阅示例数据。
以下示例
ML_TRAIN
在heatwaveml_bench.census_train
训练数据集上运行:
CALL sys.ML_TRAIN('heatwaveml_bench.census_train', 'revenue',
JSON_OBJECT('task', 'classification'), @census_model);
在哪里:
heatwaveml_bench.census_train
是包含训练数据集 (schema_name.table_name
) 的表的完全限定名称。revenue
是目标列的名称,其中包含参考标准值。-
JSON_OBJECT('task', 'classification')
指定机器学习任务类型。支持的类型是classification
(默认)regression
、 和forecasting
。NULL
JSON_OBJECT
如果您打算使用默认classification
任务类型 ,则可以指定 。使用
regression
任务类型时,只允许使用数字目标列。使用
forecasting
任务类型时,进一步添加键值对来指定索引列和要预测的列。有关说明,请参见 第 3.8 节“预测”。 @census_model
是用户定义的会话变量的名称,它在连接期间存储模型句柄。用户变量写为 . 本指南中的一些示例 用作变量名称。允许用户定义变量的任何有效名称(例如, )。@
var_name
@census_model
@my_model
模型训练好后
ML_TRAIN
,存放在用户的模型目录中。要检索生成的模型句柄,查询指定的会话变量;例如:
mysql> SELECT @census_model;
+--------------------------------------------------+
| @census_model |
+--------------------------------------------------+
| heatwaveml_bench.census_train_user1_1636729526 |
+--------------------------------------------------+
在使用与执行 相同的连接时
ML_TRAIN
,您可以指定会话变量(例如
@census_model
)代替其他 HeatWave ML 例程中的模型句柄,但会话变量数据会在当前会话终止时丢失。如果需要查找模型句柄,可以通过查询模型目录表来实现。请参阅第 3.9.8 节,“模型句柄”。
可以使用
ML_SCORE
例程评估经过训练的模型的质量和可靠性。有关详细信息,请参阅
第 3.9.6 节“评分模型”。从 MySQL 8.0.30 开始,
ML_TRAIN
如果经过训练的模型得分较低,则显示以下消息:Model Has a low training score, expect low
quality model explanations
。
该
ML_TRAIN
例程提供了可用于影响模型选择和训练的高级选项。
-
该
model_list
选项允许指定要训练的模型类型。如果指定了不止一种模型类型,则会从列表中选择最佳模型类型。有关受支持模型类型的列表,请参阅模型类型。该选项不能与exclude_model_list
选项一起使用。以下示例训练
XGBClassifier
orLGBMClassifier
模型。CALL sys.ml_train('heatwaveml_bench.census_train', 'revenue', JSON_OBJECT('task','classification', 'model_list', JSON_ARRAY('XGBClassifier', 'LGBMClassifier')), @census_model);
-
该
exclude_model_list
选项指定不应训练的模型类型。指定的模型类型不在考虑之列。有关您可以指定的模型类型的列表,请参阅 模型类型。该选项不能与model_list
选项一起使用。以下示例排除了
LogisticRegression
和GaussianNB
模型。CALL sys.ml_train('heatwaveml_bench.census_train', 'revenue', JSON_OBJECT('task','classification', 'exclude_model_list', JSON_ARRAY('LogisticRegression', 'GaussianNB')), @census_model);
-
该
optimization_metric
选项指定要优化的评分指标。有关受支持指标的列表,请参阅 评分指标。以下示例针对
neg_log_loss
指标进行了优化。CALL sys.ml_train('heatwaveml_bench.census_train', 'revenue', JSON_OBJECT('task','classification', 'optimization_metric', 'neg_log_loss'), @census_model);
-
该
exclude_column_list
选项指定在训练模型时要排除的特征列。以下示例在
'age'
为数据集训练模型时将该列排除在外census
。CALL sys.ml_train('heatwaveml_bench.census_train', 'revenue', JSON_OBJECT('task','classification', 'exclude_column_list', JSON_ARRAY('age')), @census_model);