处理不平衡数据的终极方案:Pytorch-WideDeep自定义DataLoader实战

【免费下载链接】pytorch-widedeep A flexible package for multimodal-deep-learning to combine tabular data with text and images using Wide and Deep models in Pytorch 【免费下载链接】pytorch-widedeep 项目地址: https://gitcode.com/gh_mirrors/py/pytorch-widedeep

在机器学习和深度学习任务中,不平衡数据是一个常见的挑战,它会导致模型训练偏向多数类,从而影响模型在少数类上的性能。Pytorch-WideDeep作为一个灵活的多模态深度学习框架,提供了强大的自定义DataLoader功能,帮助开发者有效处理不平衡数据问题。本文将详细介绍如何使用Pytorch-WideDeep的自定义DataLoader来解决不平衡数据问题,提升模型性能。

什么是不平衡数据?

不平衡数据指的是在分类任务中,不同类别的样本数量差异较大的情况。例如,在欺诈检测中,欺诈样本可能只占总样本的1%,而正常样本占99%。这种情况下,模型很容易倾向于预测多数类,导致少数类的识别率低下。

不平衡数据会带来诸多问题,如模型训练的偏差、评估指标的误导等。传统的解决方法包括过采样、欠采样、类别权重调整等。而Pytorch-WideDeep提供的自定义DataLoader则为处理不平衡数据提供了一种高效、灵活的方式。

Pytorch-WideDeep自定义DataLoader简介

Pytorch-WideDeep的自定义DataLoader主要包括CustomDataLoaderDataLoaderImbalanced两个类。其中,DataLoaderImbalanced是专门为处理不平衡数据设计的,它通过加权随机采样(WeightedRandomSampler)来平衡不同类别的样本数量。

DataLoaderImbalanced的核心思想是根据样本的类别权重来调整采样概率,使得少数类样本有更高的被选中概率。它的实现位于pytorch_widedeep/dataloaders.py文件中。

DataLoaderImbalanced的工作原理

  1. 计算类别权重:通过get_class_weights函数计算每个类别的权重,权重与类别样本数量成反比。
  2. 设置采样器:使用WeightedRandomSampler根据类别权重进行采样,确保每个类别都有足够的样本被选中。
  3. 动态调整:可以通过oversample_mul参数调整少数类的过采样倍数,进一步平衡数据分布。

实战:使用DataLoaderImbalanced处理不平衡数据

下面通过一个实际案例来演示如何使用Pytorch-WideDeep的DataLoaderImbalanced处理不平衡数据。我们以生物数据(bio_kdd04)为例,该数据集存在严重的类别不平衡问题。

步骤1:准备数据

首先,加载数据集并进行预处理:

import pandas as pd
from sklearn.model_selection import train_test_split
from pytorch_widedeep.datasets import load_bio_kdd04

# 加载数据
df = load_bio_kdd04(as_frame=True)
# 去除不需要的列
df.drop(columns=["EXAMPLE_ID", "BLOCK_ID"], inplace=True)
# 划分训练集、验证集和测试集
df_train, df_valid = train_test_split(df, test_size=0.2, stratify=df["target"], random_state=1)
df_valid, df_test = train_test_split(df_valid, test_size=0.5, stratify=df_valid["target"], random_state=1)

步骤2:数据预处理

使用TabPreprocessor对表格数据进行预处理:

from pytorch_widedeep.preprocessing import TabPreprocessor

continuous_cols = df.drop(columns=["target"]).columns.values.tolist()
tab_preprocessor = TabPreprocessor(continuous_cols=continuous_cols, scale=True)
X_tab_train = tab_preprocessor.fit_transform(df_train)
X_tab_valid = tab_preprocessor.transform(df_valid)
X_tab_test = tab_preprocessor.transform(df_test)

y_train = df_train["target"].values
y_valid = df_valid["target"].values
y_test = df_test["target"].values

步骤3:定义模型

使用TabMlp作为深度学习模型:

from pytorch_widedeep.models import TabMlp, WideDeep

deeptabular = TabMlp(
    column_idx=tab_preprocessor.column_idx,
    continuous_cols=tab_preprocessor.continuous_cols,
    mlp_hidden_dims=[64, 32],
)
model = WideDeep(deeptabular=deeptabular)

步骤4:使用DataLoaderImbalanced

创建DataLoaderImbalanced实例,并设置oversample_mul参数来调整过采样倍数:

from pytorch_widedeep.dataloaders import DataLoaderImbalanced
from pytorch_widedeep import Trainer
from pytorch_widedeep.metrics import Accuracy, Precision

trainer = Trainer(
    model,
    objective="binary",
    metrics=[Accuracy, Precision],
    verbose=1,
)

# 创建不平衡数据加载器,设置过采样倍数为5
train_dataloader = DataLoaderImbalanced(kwargs={"oversample_mul": 5})
eval_dataloader = DataLoaderImbalanced(kwargs={"oversample_mul": 5})

# 训练模型
trainer.fit(
    X_train={"X_tab": X_tab_train, "target": y_train},
    X_val={"X_tab": X_tab_valid, "target": y_valid},
    n_epochs=1,
    batch_size=32,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
)

上述代码中,DataLoaderImbalanced通过oversample_mul=5将少数类样本的数量增加了5倍,从而有效平衡了训练数据。完整的示例代码可以在examples/scripts/bio_imbalanced_loader.py中找到。

自监督预训练与不平衡数据处理

除了直接使用DataLoaderImbalanced外,Pytorch-WideDeep还支持通过自监督预训练来提升模型在不平衡数据上的性能。自监督学习可以通过数据增强和对比学习等方式,帮助模型学习到更鲁棒的特征表示,从而缓解数据不平衡带来的影响。

自监督SAINT模型架构

上图展示了SAINT(Self-Attention and Inter-sample Attention Transformer)模型的自监督预训练架构。该模型通过对比损失(InfoNCE loss)和去噪损失(MSE/CE loss)来学习数据的深层特征,为后续的监督学习任务打下良好的基础。自监督预训练的实现可以参考pytorch_widedeep/self_supervised_training/目录下的代码。

总结

处理不平衡数据是机器学习任务中的一个重要挑战,Pytorch-WideDeep提供的DataLoaderImbalanced为解决这一问题提供了简单而有效的方案。通过加权随机采样和过采样倍数调整,我们可以轻松地平衡训练数据,提升模型在少数类上的性能。此外,结合自监督预训练等技术,可以进一步提升模型的泛化能力和鲁棒性。

希望本文能够帮助你更好地理解和使用Pytorch-WideDeep的自定义DataLoader功能,解决实际应用中的不平衡数据问题。如果你想了解更多细节,可以参考官方文档和示例代码。

【免费下载链接】pytorch-widedeep A flexible package for multimodal-deep-learning to combine tabular data with text and images using Wide and Deep models in Pytorch 【免费下载链接】pytorch-widedeep 项目地址: https://gitcode.com/gh_mirrors/py/pytorch-widedeep

Logo

脑启社区是一个专注类脑智能领域的开发者社区。欢迎加入社区,共建类脑智能生态。社区为开发者提供了丰富的开源类脑工具软件、类脑算法模型及数据集、类脑知识库、类脑技术培训课程以及类脑应用案例等资源。

更多推荐