跳转至

导言

  • 数据集与数据加载器:学习如何使用torch.utils.data.Dataset和DataLoader来加载和处理数据。
  • 数据预处理:介绍常用的数据预处理方法,如归一化、数据增强等。

数据读取整体流程

🔑 DataLoader 数据读取的执行流程

当你写:

for batch in DataLoader(dataset, ...):
    ...

底层其实发生了这些步骤:


1. 初始化 DataLoader

  • 传入 dataset(必须实现 __len____getitem__
  • 传入采样方式:samplerbatch_sampler
  • 传入组装方式:collate_fn
  • 传入并行方式:num_workers

2. 开始迭代(调用 next)

当 Python 执行 next(dataloader_iter) 时:

  1. batch_sampler 提供索引

  2. batch_sampler 会决定一个 batch 要哪些样本。

  3. 如果你没传,默认逻辑是:

    • sampler 生成单个索引(默认是 range(len(dataset))RandomSampler
    • 再用 batch_size 把索引打包成 batch。

🔎 举例:

batch_sampler -> [ [0,1,2,3], [4,5,6,7], ... ]
![](https://pic.shaojiemike.top/shaojiemike/2025/09/f4c672e3d71fcc24d5ff5dbd1d142adf.png)
  1. dataset.getitem 取出样本

  2. DataLoader 会根据 batch_sampler 给的索引列表 [0,1,2,3]

  3. 调用 dataset.__getitem__(i)
  4. 得到一个个样本。

🔎 举例:

dataset[0] -> ("hello", 0)
dataset[1] -> ("world", 1)
...
  1. collate_fn 组装 batch

  2. [dataset[i] for i in indices] 的结果打包在一起。

  3. 默认行为是堆叠成张量(如果能堆叠),否则打包成 list。
  4. 如果你定义了 collate_fn,就在这里生效。

🔎 举例:

collate_fn([("hello", 0), ("world", 1)]) 
-> (tensor([[hello_ids], [world_ids]]), tensor([0,1]))
  1. 返回 batch

  2. next(dataloader_iter) 返回一个批次的数据(通常是张量或元组),交给训练循环。


3. dataset.len 的作用

  • sampler 初始化时会用到:

  • 比如 RandomSampler(dataset) 需要知道 len(dataset) 来决定总共多少个 index。

  • 所以 __len__ 是必须实现的,除非你用的是 IterableDataset(流式数据)。

🔎 用一个小 demo 验证

from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self):
        self.data = ["a", "b", "c", "d", "e"]

    def __len__(self):
        print(">>> __len__ called")
        return len(self.data)

    def __getitem__(self, idx):
        print(f">>> __getitem__({idx}) called")
        return self.data[idx]

def my_collate(batch):
    print(f">>> collate_fn called with: {batch}")
    return "-".join(batch)

loader = DataLoader(MyDataset(), batch_size=2, collate_fn=my_collate, shuffle=False)

for batch in loader:
    print(f"Got batch: {batch}")

输出(核心流程你能清晰看到):

>>> __len__ called   # 初始化 sampler 时
>>> __getitem__(0) called
>>> __getitem__(1) called
>>> collate_fn called with: ['a', 'b']
Got batch: a-b

>>> __getitem__(2) called
>>> __getitem__(3) called
>>> collate_fn called with: ['c', 'd']
Got batch: c-d

>>> __getitem__(4) called
>>> collate_fn called with: ['e']
Got batch: e

📌 总结执行顺序

当训练循环 next(dataloader_iter) 时,调用顺序是:

  1. batch_sampler 生成一批 index(依赖 dataset.__len__
  2. 对每个 index 调用 dataset.__getitem__ 取出样本
  3. 调用 collate_fn 把样本拼成 batch
  4. 返回 batch 给训练循环

要不要我帮你画一张 流程图(Dataset → Sampler → BatchSampler → getitem → Collate_fn → Batch),这样你一眼就能看到 DataLoader 的工作顺序?


🔑 DataLoader 常用参数解析

torch.utils.data.DataLoader 的参数

DataLoader(
    dataset,
    pin_memory=pin_memory,
    collate_fn=collate_fn,
    worker_init_fn=get_seed_worker(seed),
    num_workers=num_workers,
    batch_sampler=batch_sampler,
    prefetch_factor=prefetch_factor,
    persistent_workers=persistent_workers
)
  1. dataset

  2. 你的数据集对象,必须实现 __len____getitem__

  3. 比如 torchvision.datasets.CIFAR10,或者你自己继承 torch.utils.data.Dataset

  4. pin_memory

  5. 如果设为 True,DataLoader 会把张量放到 锁页内存 (pinned memory)

  6. 好处:GPU 拷贝更快(减少 CPU→GPU 传输的瓶颈)。
  7. 常用于 训练时 GPU 加速

  8. collate_fn

  9. 重点

  10. 作用:定义 如何把一个 batch 的样本拼在一起
  11. 默认行为是把 dataset 返回的单个样本打包成一个 batch,例如:

    batch = [dataset[i] for i in indices]
    return default_collate(batch)
    
    * 你可以自定义,比如:

    • 处理不同长度的序列 → padding 对齐
    • 处理 dict 类型样本
    • 丢弃坏数据
    • 例子:
    def my_collate(batch):
        texts, labels = zip(*batch)
        texts = pad_sequence(texts, batch_first=True)
        return texts, torch.tensor(labels)
    DataLoader(dataset, collate_fn=my_collate)
    
  12. worker_init_fn

  13. 每个 num_workers 子进程初始化时会调用这个函数。

  14. 常用于 随机种子设置,保证数据加载可复现。

  15. num_workers

  16. 启用多少个子进程来并行加载数据。

  17. 0 表示用主进程加载(最安全,但慢)。
  18. 大于 0 时可以大幅提升数据预处理速度(特别是 IO 瓶颈)。

  19. batch_sampler

  20. 重点

  21. 控制如何从 dataset 中采样 一个 batch 的 index
  22. sampler(单个样本采样器)不同,batch_sampler 一次返回一个 batch 的 index 列表
  23. 作用:完全接管 batch 的构造过程。
  24. 使用场景:

    • 动态 batch size(比如按序列长度分组)
    • 特殊采样策略(不规则 batch)
    • 注意:设置了 batch_sampler,就不能再传 batch_sizeshuffle

例子:

from torch.utils.data import BatchSampler, RandomSampler
sampler = RandomSampler(dataset)
batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=False)
DataLoader(dataset, batch_sampler=batch_sampler)
  1. prefetch_factor

  2. 每个 worker 预取多少个 batch,默认是 2

  3. 增大能减少等待时间,但会占用更多内存。

  4. persistent_workers

  5. 如果为 True,在 epoch 之间 保持 worker 存活,避免频繁 fork 子进程。

  6. 对大规模训练(多 epoch)提升效率明显。

📌 重点对比:collate_fn vs batch_sampler

  • batch_sampler:决定 抽哪些样本(index 层面)。 👉 控制“取哪些数据”。
  • collate_fn:决定 怎么拼这些样本(数据拼接层面)。 👉 控制“如何组合成 batch”。

形象比喻:

  • batch_sampler = 菜市场采购单(告诉你买哪些菜)。
  • collate_fn = 厨师拼盘(告诉你买来的菜怎么摆到一起)。

✅ 总结

  • collate_fn:把一个 batch 的数据 打包/对齐/拼接
  • batch_sampler:控制 采样逻辑,定义每个 batch 由哪些样本组成。

DataSet

ProcessorMixin

from transformers.processing_utils import ProcessorMixin 主要是引入 🤗 Transformers 库里的一个工具类,它的作用是为各种 Processor(处理器) 提供通用的功能。

在 Hugging Face 的生态里:

  • Tokenizer:处理文本 → token id。
  • Feature Extractor:处理音频、图像等输入 → 数值特征。
  • Processor:是一个“打包器”,把 tokenizer + feature extractor 组合起来,对多模态任务(如语音识别、图像字幕生成)很方便。

🔎 ProcessorMixin 的作用

ProcessorMixin 是这些 Processor 类的基类,提供了统一的接口和通用逻辑,比如:

  1. 保存 & 加载

  2. 定义了 .save_pretrained().from_pretrained() 方法

  3. 保证 processor 可以像模型一样保存到本地或从 Hub 加载
processor.save_pretrained("my_processor")
processor = AutoProcessor.from_pretrained("my_processor")
  1. 组件管理

  2. Processor 里可能有 tokenizer、feature extractor、image processor 等

  3. ProcessorMixin 负责把这些组件的配置、权重组织在一起,保证保存/加载时一致

  4. 统一调用接口

  5. Processor 本身可以像函数一样调用

  6. 内部会自动决定把输入送给 tokenizer 还是 feature extractor
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
inputs = processor("some speech.wav", return_tensors="pt", sampling_rate=16000)

这里你不需要管用的是 tokenizer 还是 feature extractor,ProcessorMixin 把逻辑打通了。


📌 举个例子

Wav2Vec2Processor 为例:

from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
print(processor.__class__.__mro__)

继承关系里就会看到 ProcessorMixin。 它让这个 processor 具备和模型/分词器一致的“预训练保存/加载”的体验。


✅ 总结: ProcessorMixin 是 Hugging Face Transformers 里所有 Processor 类的基类,主要作用是:

  • 提供保存/加载的标准方法
  • 管理多个子组件(tokenizer / feature extractor 等)
  • 统一调用接口,简化多模态处理

(图片)数据数据读取后处理

很棒的问题 👌 我帮你逐行拆解这段 prepare_images_input 代码,尤其是 images_spatial_crop 的作用。


📌 代码功能整体概览

这段函数的目标是: 👉 根据 images_spatial_crop 里提供的裁剪信息,从 images 中取出有效的图像 patch(tile),拼成一个批量输入给模型。

最终返回:

  • total_tiles:拼接后的所有 tile,形状 [batch_all_tiles, 3, H, W] (相当于把 batch 里每个样本有效的图像块统一拼接在一起)

📌 参数说明

images: torch.FloatTensor
  • 输入的图像 tensor,一般形状是 [bs, max_num_tiles, 3, H, W]
  • bs:batch size
  • max_num_tiles:单个样本最多能有多少张图(tile)
  • 3, H, W:彩色图像通道、高度、宽度
images_spatial_crop: Optional[torch.LongTensor]
  • 输入的裁剪信息,形状 [bs, max_n_images, 2]
  • 每个元素是 (num_width_tiles, num_height_tiles),表示该图被切成多少块
  • 注意:当 (0, 0) 时,说明没有更多有效图像(提前结束)

📌 核心逻辑分解

1. 判空

if images is None or images_spatial_crop.sum() == 0:
    return None
  • 如果 images 没有传,或者所有裁剪信息都是 0(即没有 tile),直接返回 None

2. 获取 batch 维度

bs, max_n_images, _ = images_spatial_crop.shape
batch_num_tiles = [0 for _ in range(bs)]
total_tiles = []
  • bs:批大小
  • max_n_images:每个样本最多能有多少张原始图像
  • batch_num_tiles:记录每个样本里 tile 的总数
  • total_tiles:用来存放每个样本的有效 tile

3. 遍历 batch

for idx in range(bs):
    for jdx in range(max_n_images):
        num_width_tiles, num_height_tiles = images_spatial_crop[idx, jdx]
        if num_width_tiles == 0 or num_height_tiles == 0:
            break
        batch_num_tiles[idx] += (1 + num_width_tiles * num_height_tiles)

对每个样本 (idx):

  • 遍历该样本的多张图像 (jdx)
  • 取出 (num_width_tiles, num_height_tiles)

  • 如果是 (0,0),说明没图了 → 提前退出循环

  • 否则计算 tile 数:

  • 1:原图

  • num_width_tiles * num_height_tiles:切出来的 patch 数
  • 累加到 batch_num_tiles[idx]

👉 举个例子: 如果 images_spatial_crop[idx, jdx] = (2, 2)

  • tile 数 = 1 + 2*2 = 5
  • 表示:原图 + 4 个裁剪 patch

4. 收集有效 tile

total_tiles.append(images[idx, :batch_num_tiles[idx]])
  • images[idx] 中取前 batch_num_tiles[idx] 张图
  • 丢到 total_tiles 列表

5. 拼接所有 batch

total_tiles = torch.cat(total_tiles, dim=0)
  • 把所有 batch 的 tile 拼接成一个大张量
  • 形状 [batch_all_tiles, 3, H, W]

6. 检查

if total_tiles.shape[0] != sum(batch_num_tiles):
    raise AssertionError
if total_tiles.shape[0] == 0:
    return None
  • 保证拼接后的 tile 数和统计的总 tile 数一致
  • 如果 tile 数为 0 → 返回 None

📌 总结

  • images:存放所有原始图像和裁剪 patch
  • images_spatial_crop:记录每张图的裁剪情况 (num_width_tiles, num_height_tiles)

  • (0,0) 表示没有更多有效图

  • (2,2) 表示切成 2×2 个小块,加上原图,总共 5 张

最终返回: 一个 [batch_all_tiles, 3, H, W] 的 tensor,把 batch 里所有有效 tile 拼起来。


要不要我帮你画一张 图示(原图 → tile → batch 拼接),更直观地展示 images_spatial_crop 是怎么控制裁剪和拼接的?