跳到主要内容

Pi0 与 Pi0.5 模型微调:基于 OpenPI 的完整流程

Pi0 / Pi0.5 是 Physical Intelligence 推出的视觉-语言-动作(VLA)模型。若你打算用艾欧数据平台导出的数据微调这类模型,本指南介绍基于官方 OpenPI 框架的完整流程。

为何选 OpenPI 而非 LeRobot 框架?

LeRobot 虽支持 Pi0 等多种模型,但对 Pi0 系列更推荐使用 OpenPI 官方训练框架:基于 JAX,原生多卡高性能训练,能更好发挥 Pi0 能力。


1. 准备:导出与放置数据

在艾欧平台将标注好的数据导出为 OpenPI 可识别的格式。

导出步骤

  1. 格式:在导出页选择 LeRobot v2.1export lerobot v2.1
  2. 解压:下载 .tar.gz 并在本地解压。
  3. 放置:为便于 OpenPI 读取,建议放到 Hugging Face 本地缓存目录,例如:
    # 准备目录
    mkdir -p ~/.cache/huggingface/lerobot/local/mylerobot

    # 将解压出的内容(含 meta、data/ 等)移入该目录
    mv /path/to/extracted/data/* ~/.cache/huggingface/lerobot/local/mylerobot/

字段映射参考(以 Aloha 三相机为例)

后续配置中需保证代码里的 Key 与数据字段一致。常用约定:

  • cam_high:顶视
  • cam_left_wrist / cam_right_wrist:左右腕视
  • state:机器人状态
  • action:目标动作(注意:OpenPI 默认 ALOHA 为 14 维;若你数据维度不同,请参考下方「深度排坑」)

2. 如何选训练配置?

OpenPI 的训练由配置驱动,选配置相当于选一个「最接近你机器人」的策略模板,再在其上微调。

需求场景推荐路径关注点
快速验证 / 打通链路仿真(LIBERO / ALOHA Sim)先把 Inputs/Outputs 对齐,成本最低
实机(双臂 Aloha)ALOHA Real对齐相机 Key、动作维度与夹爪控制逻辑
单臂 / 工业机械臂参考 UR5 示例先解决控制接口兼容,再谈效果
追求泛化对齐 DROID 数据学习 DROID 的 Norm Stats 等策略

简要建议:第一次跑先用仿真配置打通流程;上实机则选 ALOHA Real,并严格对齐 state/action 维度。


技术排坑:关于 14 维 vs 16 维动作向量

这是一个非常容易被忽视的“坑”。OpenPI 的默认 ALOHA 策略(aloha_policy.py)硬编码了 14 维 结构:

  • 默认结构[左臂6关节, 左夹爪1, 右臂6关节, 右夹爪1] = 14 维。
  • 常见问题:如果你使用的是 7 轴机械臂(如 [7, 1, 7, 1]),总维度会变成 16。此时如果不修改代码,多出的维度会被静默截断,导致训练出的模型完全无法控制最后两个关节。

修改建议:

  1. 检查你的 action 向量定义。
  2. aloha_policy.py 中,将所有的 :14 切片改为你的实际维度(如 :16)。
  3. 同步修改 _joint_flip_mask 的长度,确保正负号反转逻辑与你的硬件一致。

3. 编写训练配置

openpi/src/openpi/training/config.py 中为你的机器人添加微调配置。

# 示例:为你的机器人添加自定义配置
TrainConfig(
name="pi0_aloha_mylerobot",
model=pi0_config.Pi0Config(),
data=LeRobotAlohaDataConfig(
repo_id="local/mylerobot", # 指向之前放置数据的目录
assets=AssetsConfig(
assets_dir="/home/user/code/openpi/assets/pi0_aloha_mylerobot",
),
default_prompt="fold the clothes", # 任务描述,非常重要
repack_transforms=_transforms.Group(
inputs=[
_transforms.RepackTransform(
{
"images": {
"cam_high": "observation.images.cam_high",
"cam_left_wrist": "observation.images.cam_left_wrist",
"cam_right_wrist": "observation.images.cam_right_wrist",
},
"state": "observation.state",
"actions": "action",
}
)
]
),
),
weight_loader=weight_loaders.CheckpointWeightLoader("gs://openpi-assets/checkpoints/pi0_base/params"),
num_train_steps=20_000,
)

4. 开始训练:计算统计并运行

训练前必须先跑归一化统计,否则输入数值范围会错乱。

第一步:计算 Norm Stats

uv run scripts/compute_norm_stats.py --config-name pi0_aloha_mylerobot

第二步:启动微调

建议使用 JAX 模式以获得最佳性能。

单卡模式:

export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9
CUDA_VISIBLE_DEVICES=0 uv run scripts/train.py pi0_aloha_mylerobot \
--exp-name=my_first_experiment \
--overwrite

多卡并行 (FSDP):

uv run scripts/train.py pi0_aloha_mylerobot --exp-name=multi_gpu_run --fsdp-devices 4

5. 推理与部署

微调完成后可启动策略服务器,在实机或仿真中运行策略。

# 启动推理服务器,默认端口 8000
uv run scripts/serve_policy.py policy:checkpoint \
--policy.config=pi0_aloha_mylerobot \
--policy.dir=experiments/my_first_experiment/checkpoints/last

6. 常见问题(FAQ)

  • Q: 显存 OOM?
    调小 batch_size,或确认 XLA_PYTHON_CLIENT_MEM_FRACTION;多卡 FSDP 也能缓解显存压力。
  • Q: 动作异常或关节不动?
    检查 RepackTransform 映射是否正确;重点回顾上文 14 维 vs 16 维,确认是否被静默截断。
  • Q: Loss 不降?
    检查 default_prompt 是否准确;确认 compute_norm_stats 生成的统计是否生效。

已知兼容性问题与修复(lerobot 框架训练 Pi0)

使用 lerobot 框架(非 OpenPI)在本地训练 Pi0 时,lerobot 的 Pi0 实现依赖 transformers 的特定 fork,标准 pip install transformers 可能遇到以下不兼容。

问题一:v2 数据 + lerobot v0.3.3 — GemmaForCausalLM 属性访问错误

症状:

AttributeError: 'GemmaForCausalLM' object has no attribute 'embed_tokens'
AttributeError: 'GemmaForCausalLM' object has no attribute 'layers'

根因: lerobot v0.3.3 的 pi0 实现(paligemma_with_expert.py)直接在 GemmaForCausalLM 实例上访问 .embed_tokens.layers.norm。但在 transformers ≥ 4.40(含 4.51.x)中,这些属性位于嵌套子模块 .model(即 GemmaModel)下,而不再暴露在顶层模型上。

修复方法: 在 lerobot 安装目录下找到 lerobot/policies/pi0/paligemma_with_expert.py,修改以下两处:

1. 修复 embed_language_tokens 方法中的 embed_tokens 访问:

# 修改前
return self.paligemma.language_model.embed_tokens(tokens)

# 修改后
_lm = self.paligemma.language_model
_et = getattr(_lm, "embed_tokens", None) or _lm.model.embed_tokens
return _et(tokens)

2. 修复 forward 方法中 models 列表的 layers / norm 访问:

# 修改前(GemmaForCausalLM 无 .layers/.norm,访问报错)
models = [self.paligemma.language_model, self.gemma_expert.model]

# 修改后(取 .model 子模块 GemmaModel,它才有 .layers/.norm)
_pg_lm = self.paligemma.language_model
models = [getattr(_pg_lm, "model", _pg_lm), self.gemma_expert.model]

也可以用以下脚本自动应用两处 patch(在 lerobot 环境的 Python 中执行):

from pathlib import Path

path = Path(".venv/lib/python3.10/site-packages/lerobot/policies/pi0/paligemma_with_expert.py")
text = path.read_text()

old1 = "return self.paligemma.language_model.embed_tokens(tokens)"
new1 = '_lm = self.paligemma.language_model; _et = getattr(_lm, "embed_tokens", None) or _lm.model.embed_tokens; return _et(tokens)'
if old1 in text:
text = text.replace(old1, new1, 1)
print("Patched: embed_tokens")

old2 = "models = [self.paligemma.language_model, self.gemma_expert.model]"
new2 = '_pg_lm = self.paligemma.language_model; models = [getattr(_pg_lm, "model", _pg_lm), self.gemma_expert.model]'
if old2 in text:
text = text.replace(old2, new2, 1)
print("Patched: models list (layers/norm)")

path.write_text(text)
print("Done")

问题二:v3 数据 + lerobot v0.4.3 — transformers 版本检查失败

症状:

ValueError: An incorrect transformer version is used

根因: lerobot v0.4.3 的 pi0 实现在启动时会调用:

transformers.models.siglip.check.check_whether_transformers_replace_is_installed_correctly()

该函数仅存在于 lerobot 团队维护的 transformers fork(fix/lerobot_openpi 分支),标准版 transformers 没有此模块。

此外,lerobot v0.4.3 的 Pi0 还依赖:

  • GemmaRMSNorm.forward(self, x, cond=None) 的扩展参数签名(标准版只有 forward(self, x)
  • _gated_residual(x, y, gate) 辅助函数(标准版 transformers 不存在)

修复方法: 对已安装的标准 transformers 进行原地补丁(在 lerobot 环境的 Python 中执行):

from pathlib import Path
import transformers

td = Path(transformers.__file__).parent

# 1. 创建缺失的 siglip/check.py(版本检查函数)
ck = td / "models" / "siglip" / "check.py"
if not ck.exists():
ck.write_text(
"def check_whether_transformers_replace_is_installed_correctly():\n"
" return True\n"
)
print("[pi0-v3] Created siglip/check.py")

# 2. 修补 modeling_gemma.py:扩展 GemmaRMSNorm.forward 并添加 _gated_residual
gf = td / "models" / "gemma" / "modeling_gemma.py"
txt = gf.read_text()
if "_gated_residual" not in txt:
# 扩展 forward 签名
txt = txt.replace(
" def forward(self, x):\n output = self._norm(x.float())",
" def forward(self, x, cond=None):\n output = self._norm(x.float())",
1,
)
# 修改返回值为 tuple (output, None)
txt = txt.replace(
" return output.type_as(x)\n\n def extra_repr",
" return output.type_as(x), None\n\n def extra_repr",
1,
)
# 在 GemmaMLP 类前插入 _gated_residual 函数
gated = (
"\n\ndef _gated_residual(x, y, gate):\n"
" if gate is None:\n"
" return x + y\n"
" return x + y * gate\n\n"
)
txt = txt.replace("\nclass GemmaMLP", gated + "class GemmaMLP", 1)
gf.write_text(txt)
print("[pi0-v3] Patched modeling_gemma.py (GemmaRMSNorm + _gated_residual)")
else:
print("[pi0-v3] modeling_gemma.py already patched")
注意

上述补丁直接修改已安装的 transformers 文件,重新安装 transformers 会导致补丁丢失。建议在独立虚拟环境中进行,或在每次重装后重新应用。


更多参考资料