跳到主要内容

ACT 模型训练指南

本文介绍如何使用艾欧智能发布的 Docker 镜像 ioaitech/train_act:cuda 训练 ACT 模型。文中的输入输出目录、参数名和默认值与镜像内约定一致。

国内镜像加速(华为云 SWR)

文档默认以 Docker Hub 上的 ioaitech/train_act:cuda 为例。若你在中国大陆访问 Docker Hub 较慢,可使用华为云容器镜像服务同步地址,将镜像名前缀替换为 swr.cn-east-3.myhuaweicloud.com/ioaitech/,例如:swr.cn-east-3.myhuaweicloud.com/ioaitech/train_act:cudadocker run 的其余参数不变。

适用范围

ACT 适合任务边界清晰、动作模式相对稳定的模仿学习场景。若你的目标是先把单任务训练链路稳定跑通,再逐步调参,ACT 仍然是很务实的选择。

本指南默认你已经在艾欧数据平台完成数据标注,并导出了 LeRobot 格式数据集。

一键开始训练

前置条件

  • Linux 主机
  • NVIDIA 驱动安装正常
  • Docker 可用
  • docker run --gpus all 可正常访问 GPU

建议先做一次快速自检:

docker run --rm --gpus all nvidia/cuda:12.1.0-base-ubuntu22.04 nvidia-smi

最小可运行命令

将本地 LeRobot 数据集挂载到 /data/input,将训练输出挂载到 /data/output

docker run --rm --gpus all \
-v /path/to/lerobot_dataset:/data/input \
-v /path/to/output:/data/output \
ioaitech/train_act:cuda \
--run_name act_demo \
--task_name demo_task \
--num_epochs 1 \
--batch_size 8

这条命令适合先验证数据挂载、格式识别和训练流程是否打通。确认链路没有问题后,再把训练轮数和参数调到正式值。

推荐的一键训练脚本

下面这组参数更接近实际训练场景,适合作为公开文档中的首选模板:

docker run --rm --gpus all \
-v /path/to/lerobot_dataset:/data/input \
-v /path/to/output:/data/output \
ioaitech/train_act:cuda \
--run_name lemon_act_v1 \
--task_name pick_lemon \
--num_epochs 12000 \
--batch_size 64 \
--learning_rate 5e-5 \
--chunk_size 100 \
--kl_weight 10 \
--hidden_dim 512 \
--dim_feedforward 3200 \
--batch_mode fixed_global \
--gpus all

如果只想使用指定 GPU,可以把最后一行改成 --gpus 0--gpus 0,1 这类形式。

数据要求

容器启动后会检查 /data/input/meta/info.json 是否存在;缺少该文件时会直接报错并退出。因此在启动训练前,请确认数据集根目录至少包含以下结构:

your_dataset/
├── meta/
│ └── info.json
├── data/
└── videos/

当前训练入口支持 LeRobot v2v3 数据。脚本会自动识别版本,并在需要时完成兼容处理与中间转换。

相机字段

如果数据集的图像字段命名符合常见约定,脚本会自动从 meta/info.json 中推断相机键。若你的字段名较特殊,建议显式传入:

docker run --rm --gpus all \
-v /path/to/lerobot_dataset:/data/input \
-v /path/to/output:/data/output \
ioaitech/train_act:cuda \
--run_name multi_cam_exp \
--task_name tron2_task \
--camera_keys observation.images.cam_high,observation.images.cam_right_wrist,observation.images.cam_left_wrist \
--camera_names cam_high,cam_right_wrist,cam_left_wrist

其中:

  • --camera_keys 是 LeRobot 数据中的图像特征键
  • --camera_names 是 ACT 训练侧使用的相机名

两者的顺序应严格对应。

常用参数

下表与镜像内 train_lerobot_to_act.py 的参数定义一致。

训练核心参数

参数默认值说明
--batch_size64训练 batch size
--num_epochs12000主训练轮数
--steps0num_epochs 的兼容别名;仅在 num_epochs=0 时使用
--learning_rate5e-5主学习率
--save_interval6000中间存档间隔
--gpusall使用全部 GPU,或传 0,1 这类列表
--batch_modefixed_global多卡时保持全局 batch 语义更接近单卡参考值
--num_workers0容器环境下推荐保持默认,降低 /dev/shm 风险

ACT 模型参数

参数默认值说明
--task_nameauto自动从数据集中推断任务名;失败时会回退
--run_name自动生成checkpoint 子目录名
--policy_classACT一般保持默认
--kl_weight10KL 项权重
--chunk_size100动作 chunk 长度
--hidden_dim512Transformer 隐层维度
--dim_feedforward3200前馈层维度
--seed42随机种子

数据桥接参数

参数默认值说明
--camera_keys自动推断指定 LeRobot 图像字段
--camera_names自动生成指定 ACT 相机名
--episode_len0强制覆盖 episode 长度
--idle_threshold1e-4静止帧过滤阈值
--max_episodes0仅转换前 N 个 episode,适合 smoke test
--convert_workers0转换阶段并发 worker 数
--keep_converted_hdf5关闭保留中间 HDF5 文件,便于排查问题

输出结果

训练输出会写入挂载的 /data/output。典型产物如下:

/path/to/output/
├── checkpoints/
│ └── <run_name>/
│ ├── policy_best.ckpt
│ ├── policy_last.ckpt
│ └── dataset_stats.pkl
└── manifest.json

其中:

  • policy_best.ckpt 是训练过程中表现较好的 checkpoint
  • policy_last.ckpt 是最后一次保存的 checkpoint
  • dataset_stats.pkl 是训练时使用的数据统计信息
  • manifest.json 记录本次训练的关键信息

多卡训练建议

当前实现已经把多卡训练路径封装在容器内部,常规情况下不需要手动改 torchrun 命令。建议遵循以下原则:

  • 优先使用 --batch_mode fixed_global,更容易与单卡结果对齐
  • 容器内 --num_workers 建议保持 0
  • 初次多卡实验可以先配合 --max_episodes 10 做快速验证

若你明确知道自己要追求更高吞吐,再考虑切换 --batch_mode fixed_per_gpu

常见问题

1. 容器启动后提示找不到数据集

先检查两件事:

  • 宿主机路径是否正确挂载到 /data/input
  • /data/input/meta/info.json 是否存在

缺少 info.json 时,容器会直接退出,这通常不是训练代码问题,而是数据路径或目录层级不正确。

2. 多卡训练出现 DataLoader 异常或 NCCL 超时

优先尝试以下做法:

  • 保持 --num_workers 0
  • --convert_workers 调小到 24
  • 先缩小数据规模,用 --max_episodes 做短流程验证

3. 如何设置 task_name

如果数据集中的任务字段完整,--task_name auto 通常即可工作。若你的数据集任务定义较复杂,建议显式指定任务名,便于后续管理输出目录和实验记录。

4. 首次训练速度偏慢

镜像构建阶段已经预下载了 ResNet18 权重,但数据转换和首轮数据加载仍然需要时间。只要日志在持续推进,通常属于正常现象。

实践建议

  • 先用 1 个 epoch 或少量 episode 做流程验证,再跑正式训练
  • 固定一组基础参数作为对照组,后续每次只调整少量变量
  • 不要只看最后一个 checkpoint,建议同时比较若干中间 checkpoint 的实际效果

参考资料