ACT 模型训练指南
本文介绍如何使用艾欧智能发布的 Docker 镜像 ioaitech/train_act:cuda 训练 ACT 模型。文中的输入输出目录、参数名和默认值与镜像内约定一致。
文档默认以 Docker Hub 上的 ioaitech/train_act:cuda 为例。若你在中国大陆访问 Docker Hub 较慢,可使用华为云容器镜像服务同步地址,将镜像名前缀替换为 swr.cn-east-3.myhuaweicloud.com/ioaitech/,例如:swr.cn-east-3.myhuaweicloud.com/ioaitech/train_act:cuda。docker run 的其余参数不变。
适用范围
ACT 适合任务边界清晰、动作模式相对稳定的模仿学习场景。若你的目标是先把单任务训练链路稳定跑通,再逐步调参,ACT 仍然是很务实的选择。
本指南默认你已经在艾欧数据平台完成数据标注,并导出了 LeRobot 格式数据集。
一键开始训练
前置条件
- Linux 主机
- NVIDIA 驱动安装正常
- Docker 可用
docker run --gpus all可正常访问 GPU
建议先做一次快速自检:
docker run --rm --gpus all nvidia/cuda:12.1.0-base-ubuntu22.04 nvidia-smi
最小可运行命令
将本地 LeRobot 数据集挂载到 /data/input,将训练输出挂载到 /data/output:
docker run --rm --gpus all \
-v /path/to/lerobot_dataset:/data/input \
-v /path/to/output:/data/output \
ioaitech/train_act:cuda \
--run_name act_demo \
--task_name demo_task \
--num_epochs 1 \
--batch_size 8
这条命令适合先验证数据挂载、格式识别和训练流程是否打通。确认链路没有问题后,再把训练轮数和参数调到正式值。
推荐的一键训练脚本
下面这组参数更接近实际训练场景,适合作为公开文档中的首选模板:
docker run --rm --gpus all \
-v /path/to/lerobot_dataset:/data/input \
-v /path/to/output:/data/output \
ioaitech/train_act:cuda \
--run_name lemon_act_v1 \
--task_name pick_lemon \
--num_epochs 12000 \
--batch_size 64 \
--learning_rate 5e-5 \
--chunk_size 100 \
--kl_weight 10 \
--hidden_dim 512 \
--dim_feedforward 3200 \
--batch_mode fixed_global \
--gpus all
如果只想使用指定 GPU,可以把最后一行改成 --gpus 0、--gpus 0,1 这类形式。
数据要求
容器启动后会检查 /data/input/meta/info.json 是否存在;缺少该文件时会直接报错并退出。因此在启动训练前,请确认数据集根目录至少包含以下结构:
your_dataset/
├── meta/
│ └── info.json
├── data/
└── videos/
当前训练入口支持 LeRobot v2 和 v3 数据。脚本会自动识别版本,并在需要时完成兼容处理与中间转换。
相机字段
如果数据集的图像字段命名符合常见约定,脚本会自动从 meta/info.json 中推断相机键。若你的字段名较特殊,建议显式传入:
docker run --rm --gpus all \
-v /path/to/lerobot_dataset:/data/input \
-v /path/to/output:/data/output \
ioaitech/train_act:cuda \
--run_name multi_cam_exp \
--task_name tron2_task \
--camera_keys observation.images.cam_high,observation.images.cam_right_wrist,observation.images.cam_left_wrist \
--camera_names cam_high,cam_right_wrist,cam_left_wrist
其中:
--camera_keys是 LeRobot 数据中的图像特征键--camera_names是 ACT 训练侧使用的相机名
两者的顺序应严格对应。
常用参数
下表与镜像内 train_lerobot_to_act.py 的参数定义一致。
训练核心参数
| 参数 | 默认值 | 说明 |
|---|---|---|
--batch_size | 64 | 训练 batch size |
--num_epochs | 12000 | 主训练轮数 |
--steps | 0 | num_epochs 的兼容别名;仅在 num_epochs=0 时使用 |
--learning_rate | 5e-5 | 主学习率 |
--save_interval | 6000 | 中间存档间隔 |
--gpus | all | 使用全部 GPU,或传 0,1 这类列表 |
--batch_mode | fixed_global | 多卡时保持全局 batch 语义更接近单卡参考值 |
--num_workers | 0 | 容器环境下推荐保持默认,降低 /dev/shm 风险 |
ACT 模型参数
| 参数 | 默认值 | 说明 |
|---|---|---|
--task_name | auto | 自动从数据集中推断任务名;失败时会回退 |
--run_name | 自动生成 | checkpoint 子目录名 |
--policy_class | ACT | 一般保持默认 |
--kl_weight | 10 | KL 项权重 |
--chunk_size | 100 | 动作 chunk 长度 |
--hidden_dim | 512 | Transformer 隐层维度 |
--dim_feedforward | 3200 | 前馈层维度 |
--seed | 42 | 随机种子 |
数据桥接参数
| 参数 | 默认值 | 说明 |
|---|---|---|
--camera_keys | 自动推断 | 指定 LeRobot 图像字段 |
--camera_names | 自动生成 | 指定 ACT 相机名 |
--episode_len | 0 | 强制覆盖 episode 长度 |
--idle_threshold | 1e-4 | 静止帧过滤阈值 |
--max_episodes | 0 | 仅转换前 N 个 episode,适合 smoke test |
--convert_workers | 0 | 转换阶段并发 worker 数 |
--keep_converted_hdf5 | 关闭 | 保留中间 HDF5 文件,便于排查问题 |
输出结果
训练输出会写入挂载的 /data/output。典型产物如下:
/path/to/output/
├── checkpoints/
│ └── <run_name>/
│ ├── policy_best.ckpt
│ ├── policy_last.ckpt
│ └── dataset_stats.pkl
└── manifest.json
其中:
policy_best.ckpt是训练过程中表现较好的 checkpointpolicy_last.ckpt是最后一次保存的 checkpointdataset_stats.pkl是训练时使用的数据统计信息manifest.json记录本次训练的关键信息
多卡训练建议
当前实现已经把多卡训练路径封装在容器内部,常规情况下不需要手动改 torchrun 命令。建议遵循以下原则:
- 优先使用
--batch_mode fixed_global,更容易与单卡结果对齐 - 容器内
--num_workers建议保持0 - 初次多卡实验可以先配合
--max_episodes 10做快速验证
若你明确知道自己要追求更高吞吐,再考虑切换 --batch_mode fixed_per_gpu。
常见问题
1. 容器启动后提示找不到数据集
先检查两件事:
- 宿主机路径是否正确挂载 到
/data/input /data/input/meta/info.json是否存在
缺少 info.json 时,容器会直接退出,这通常不是训练代码问题,而是数据路径或目录层级不正确。
2. 多卡训练出现 DataLoader 异常或 NCCL 超时
优先尝试以下做法:
- 保持
--num_workers 0 - 将
--convert_workers调小到2或4 - 先缩小数据规模,用
--max_episodes做短流程验证
3. 如何设置 task_name
如果数据集中的任务字段完整,--task_name auto 通常即可工作。若你的数据集任务定义较复杂,建议显式指定任务名,便于后续管理输出目录和实验记录。
4. 首次训练速度偏慢
镜像构建阶段已经预下载了 ResNet18 权重,但数据转换和首轮数据加载仍然需要时间。只要日志在持续推进,通常属于正常现象。