HDF5数据集的模型训练
本文档介绍了如何使用平台导出的 HDF5 数据集进行各种机器人学习模型的训练,包括模仿学习、强化学习和视觉-语言-动作(VLA)模型等。
数据预处理
在开始训练之前,通常需要对 HDF5 数据集进行预处理,以适应不同的模型架构和训练框架。
基础数据加载器
import h5py
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import io
class RobotDataset(Dataset):
    """Map-style dataset that serves one HDF5 episode per index.

    Every child of the ``/data`` group in each given file is treated as one
    episode. Files are opened only while they are being indexed or read, so
    no open HDF5 handle is stored on the instance.

    Args:
        hdf5_files: Iterable of paths to HDF5 files. Each file must contain
            a ``/data`` group whose children are the episodes.
        transform: Optional callable applied to every decoded camera frame
            (a ``PIL.Image``). When ``None`` the frames are returned exactly
            as decoded.
    """

    def __init__(self, hdf5_files, transform=None):
        self.hdf5_files = hdf5_files
        self.transform = transform
        # Flat index of (file_path, episode_name) pairs, one per episode.
        self.episodes = []

        # Index all episodes up front so __len__/__getitem__ are cheap lookups.
        for file_path in hdf5_files:
            with h5py.File(file_path, 'r') as f:
                data_group = f['/data']
                for episode_name in data_group.keys():
                    self.episodes.append((file_path, episode_name))

    def __len__(self) -> int:
        """Return the number of indexed episodes across all files."""
        return len(self.episodes)

    def __getitem__(self, idx):
        """Load a single episode.

        Args:
            idx: Position in the episode index built by ``__init__``.

        Returns:
            A dict with the keys ``actions``, ``states`` and ``gripper``
            (float32 tensors), ``images`` (camera name -> list of frames),
            and the episode attributes ``task``, ``task_zh`` and ``score``.

        Raises:
            IndexError: If ``idx`` is outside the episode index.
            KeyError: If the episode lacks ``action``, ``observation.state``
                or ``observation.gripper``.
        """
        file_path, episode_name = self.episodes[idx]

        with h5py.File(file_path, 'r') as f:
            episode = f[f'/data/{episode_name}']

            # Action data.
            actions = episode['action'][:]
            # State data.
            states = episode['observation.state'][:]
            # Gripper state.
            gripper = episode['observation.gripper'][:]

            # Image data: one list of frames per camera.
            images = {}
            for key in episode.keys():
                if not key.startswith('observation.images.'):
                    continue
                # The camera name is the last dotted component of the key.
                camera_name = key.split('.')[-1]
                # `[:]` reads the encoded frames into memory, so the images
                # (which Pillow decodes lazily) stay usable after the file
                # is closed.
                img_data = episode[key][:]
                frames = [Image.open(io.BytesIO(frame)) for frame in img_data]
                # Apply the user-supplied transform; previously it was
                # stored on the instance but never used.
                # NOTE(review): assumes a per-image (torchvision-style)
                # transform; confirm against callers.
                if self.transform is not None:
                    frames = [self.transform(frame) for frame in frames]
                images[camera_name] = frames

            # Task description and score, with defaults when absent.
            task = episode.attrs.get('task', '')
            task_zh = episode.attrs.get('task_zh', '')
            score = episode.attrs.get('score', 0.0)

        return {
            'actions': torch.as_tensor(actions, dtype=torch.float32),
            'states': torch.as_tensor(states, dtype=torch.float32),
            'gripper': torch.as_tensor(gripper, dtype=torch.float32),
            'images': images,
            'task': task,
            'task_zh': task_zh,
            'score': score,
        }