FTTransformer:让Transformer在表格数据上也能‘拳拳到肉’
1. 为什么表格数据需要Transformer表格数据Tabular Data是我们日常处理最多的一种数据类型从电商的用户行为记录到金融领域的交易流水几乎无处不在。传统上这类数据的建模任务往往交给梯度提升树如XGBoost、LightGBM来处理因为它们对特征工程的要求相对较低且在小规模数据上表现优异。但深度学习模型尤其是Transformer架构在自然语言处理和计算机视觉领域的成功让我们不禁思考能否让Transformer在表格数据上也拳拳到肉这里的关键挑战在于表格数据与文本或图像数据有本质区别。文本数据天然具有序列性每个token单词或字都有明确的语义图像数据则具有局部相关性相邻像素之间关系密切。而表格数据的特征之间可能完全独立也可能存在复杂的非线性关系这种不确定性让传统的Transformer难以直接应用。FTTransformerFeature Tokenizer Transformer正是为解决这一问题而生。它通过特征令牌化Feature Tokenization将每个特征无论是连续型还是离散型都转化为一个向量表示就像把单词转化为词向量一样。这样一来原本扁平的表格数据就变成了一个特征序列Transformer的自注意力机制就能在这些特征之间建立联系。2. FTTransformer的核心设计思想2.1 特征令牌化让表格数据说Transformer的语言FTTransformer的第一个创新点在于它的特征处理方式。对于传统机器学习模型我们通常会这样处理特征连续特征直接使用或标准化离散特征one-hot编码或嵌入(Embedding)但FTTransformer采用了更统一的方式# 传统处理方式以LightGBM为例 continuous_features [age, income] categorical_features [gender, city] # FTTransformer的处理方式 all_features continuous_features categorical_features # 每个特征都会被映射为一个d维向量这种处理有三大优势统一表示无论特征类型如何最终都转化为相同维度的向量便于Transformer处理保留语义每个特征的向量可以在训练过程中学习到最适合任务的含义灵活扩展新特征的加入不会破坏现有结构2.2 针对表格数据的架构调整原始的Transformer有几个不太适合表格数据的设计FTTransformer做了针对性改进移除首个LayerNorm实验发现在表格数据上输入层立即进行归一化会损失信息引入CLS Token借鉴BERT的做法添加一个特殊的[CLS]标记其最终输出用作整个表格的表示特征交互设计允许特征token之间充分交互同时保留原始特征信息这些改动看似微小但在实际任务中却能带来显著提升。比如在植被覆盖类型预测任务中准确率从LightGBM的83%提升到了91%。3. 实战用FTTransformer预测植被覆盖类型3.1 数据准备与预处理我们使用Covertype数据集这是一个经典的多元分类数据集目标是预测美国科罗拉多州不同地区的植被类型。数据包含54个特征10个连续型44个二元型7个类别。import pandas as pd from sklearn.model_selection import train_test_split # 加载数据 df pd.read_parquet(covertype.parquet) # 划分数据集 dftmp, dftest train_test_split(df, test_size0.2, random_state42) dftrain, dfval train_test_split(dftmp, test_size0.2, random_state42) print(f训练集: {len(dftrain)}条, 验证集: {len(dfval)}条, 测试集: {len(dftest)}条)数据预处理的关键步骤连续特征标准化离散特征编码目标变量转换FTTransformer提供了一个方便的TabularPreprocessor来简化这些步骤from torchkeras.tabular import TabularPreprocessor preprocessor TabularPreprocessor( continuous_cols[Elevation, Aspect, Slope, ...], categorical_cols[Wilderness_Area, Soil_Type], target_colCover_Type, taskclassification ) # 拟合预处理管道 preprocessor.fit(dftrain) # 应用转换 dftrain_preprocessed preprocessor.transform(dftrain)3.2 模型构建与训练使用torchkeras库可以轻松构建FTTransformer模型from torchkeras.tabular.models import FTTransformerConfig, FTTransformerModel # 配置模型 config FTTransformerConfig( taskclassification, num_attn_blocks3, # Transformer块的数量 embedding_dim32, # 每个特征的嵌入维度 num_heads4 # 注意力头数 ) model FTTransformerModel(config) # 数据感知的权重初始化 model.data_aware_initialization(train_loader)训练过程采用标准的PyTorch流程但加入了早停和模型检查点等实用功能from torchkeras import KerasModel import torch # 自定义准确率计算 class Accuracy(torch.nn.Module): def __init__(self): super().__init__() self.correct torch.tensor(0) self.total torch.tensor(0) def forward(self, preds, targets): preds preds.argmax(dim-1) targets targets.reshape(-1) correct (preds targets).sum() total targets.shape[0] self.correct correct self.total total return correct.float() / total # 配置训练 keras_model KerasModel( model, optimizertorch.optim.AdamW(model.parameters(), lr1e-3), metrics_dict{acc: Accuracy()} ) # 开始训练 history keras_model.fit( train_datatrain_loader, val_dataval_loader, epochs20, patience5, monitorval_acc, modemax )3.3 性能评估与对比训练完成后我们可以在测试集上评估模型test_metrics keras_model.evaluate(test_loader) print(f测试集准确率: {test_metrics[acc]:.4f})为了展示FTTransformer的优势我们与LightGBM进行对比import lightgbm as lgb # 训练LightGBM lgb_model lgb.LGBMClassifier( n_estimators500, learning_rate0.01, num_leaves31 ) lgb_model.fit(X_train, y_train) # 评估 lgb_score lgb_model.score(X_test, y_test) print(fLightGBM准确率: {lgb_score:.4f})在植被覆盖预测任务中典型的结果对比可能是FTTransformer: 91.0% 准确率LightGBM: 83.3% 准确率这种差距在多个表格数据任务中都得到了验证说明FTTransformer确实能够超越传统方法。4. FTTransformer的适用场景与调优建议4.1 何时选择FTTransformer虽然FTTransformer表现优异但并不是所有场景都适用。根据我的经验以下情况特别适合中等规模数据1万-100万样本小数据可能过拟合超大数据需要更多计算资源特征交互复杂当特征间存在难以手动设计的复杂关系时混合特征类型同时包含连续型和离散型特征的任务需要高精度当业务场景对模型精度要求极高时4.2 关键超参数调优要让FTTransformer发挥最佳性能有几个关键参数需要注意embedding_dim特征嵌入的维度通常32-128之间num_attn_blocksTransformer块的数量2-4层通常足够num_heads注意力头数建议从4开始尝试learning_rate1e-4到1e-3之间比较合适一个实用的调优策略是from torchkeras.tuner import Tuner tuner Tuner( model_classFTTransformerModel, config_classFTTransformerConfig, target_metricval_acc ) best_config tuner.tune( train_loader, val_loader, param_space{ embedding_dim: [32, 64, 128], num_attn_blocks: [2, 3, 4], num_heads: [4, 8], learning_rate: [1e-4, 3e-4, 1e-3] }, max_trials20 )4.3 常见问题与解决方案在实际使用中可能会遇到以下问题训练不稳定尝试减小学习率增加batch size使用梯度裁剪过拟合增加dropout率使用早停添加L2正则化内存不足减小batch size降低embedding维度使用混合精度训练我在一个客户流失预测项目中就遇到过内存问题通过将embedding_dim从64降到32不仅解决了内存问题准确率还略有提升这可能是因为小维度在中等规模数据上反而更不容易过拟合。