Training with HDF5

This page explains how to train robot learning models on exported HDF5 datasets, covering imitation learning, reinforcement learning (RL), and vision-language-action (VLA) models.

Preprocessing

Prepare the exported data so it matches the input format your model and framework expect.

Basic loader

import h5py
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import io

class RobotDataset(Dataset):
    """Loads episodes from one or more exported HDF5 files."""

    def __init__(self, hdf5_files, transform=None):
        self.hdf5_files = hdf5_files
        self.transform = transform
        # Index every episode as a (file, episode name) pair so __getitem__
        # can open files lazily instead of holding everything in memory.
        self.episodes = []
        for file_path in hdf5_files:
            with h5py.File(file_path, 'r') as f:
                data_group = f['/data']
                for episode_name in data_group.keys():
                    self.episodes.append((file_path, episode_name))

    def __len__(self):
        return len(self.episodes)

    def __getitem__(self, idx):
        file_path, episode_name = self.episodes[idx]
        with h5py.File(file_path, 'r') as f:
            episode = f[f'/data/{episode_name}']

            # Actions and proprioceptive signals are stored as numeric arrays.
            actions = episode['action'][:]
            states = episode['observation.state'][:]
            gripper = episode['observation.gripper'][:]

            # Camera streams are stored as encoded image bytes, one dataset
            # per camera under observation.images.<camera_name>.
            images = {}
            for key in episode.keys():
                if key.startswith('observation.images.'):
                    cam = key.split('.')[-1]
                    img_data = episode[key][:]
                    images[cam] = [Image.open(io.BytesIO(frame)) for frame in img_data]

            # Episode-level metadata lives in HDF5 attributes.
            task = episode.attrs.get('task', '')
            task_zh = episode.attrs.get('task_zh', '')
            score = episode.attrs.get('score', 0.0)

            return {
                'actions': torch.FloatTensor(actions),
                'states': torch.FloatTensor(states),
                'gripper': torch.FloatTensor(gripper),
                'images': images,
                'task': task,
                'task_zh': task_zh,
                'score': score,
            }
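
Episodes vary in length and the images field holds PIL objects, so PyTorch's default collate cannot stack them into a batch. A minimal sketch of a whole-episode collate function and loader (the file names are placeholders; models that consume fixed-length windows of stacked tensors, as in the training loops below, would typically do the windowing and stacking here instead):

def episode_collate(batch):
    # Keep each episode as its own list entry; downstream code can window,
    # pad, or stack the sequences as the chosen model requires.
    return {key: [ep[key] for ep in batch] for key in batch[0]}

dataset = RobotDataset(['episode_000.hdf5', 'episode_001.hdf5'])
loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=episode_collate)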

Image transforms

import torchvision.transforms as transforms

image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def preprocess_images(images_dict, transform):
    processed = {}
    for cam, imgs in images_dict.items():
        processed[cam] = torch.stack([transform(img) for img in imgs])  # (T, C, H, W)
    return processed
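
For a single episode from the dataset above, the transform is applied per camera before batching; the result is one (T, 3, 224, 224) tensor per camera:

sample = dataset[0]
pixels = preprocess_images(sample['images'], image_transform)  # dict: camera -> (T, 3, 224, 224)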

Imitation learning

Behavior cloning

import torch.nn as nn
import torch.optim as optim

class BehaviorCloningModel(nn.Module):
    def __init__(self, state_dim, action_dim, image_channels=3):
        super().__init__()
        # Small CNN image encoder -> 256-d feature.
        self.image_encoder = nn.Sequential(
            nn.Conv2d(image_channels, 32, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d((4, 4)), nn.Flatten(), nn.Linear(128 * 4 * 4, 256)
        )
        # MLP state encoder -> 128-d feature.
        self.state_encoder = nn.Sequential(nn.Linear(state_dim, 128), nn.ReLU(), nn.Linear(128, 128))
        # Fuse image and state features and regress the action.
        self.fusion = nn.Sequential(
            nn.Linear(256 + 128, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, action_dim)
        )

    def forward(self, images, states):
        img_feat = self.image_encoder(images)
        st_feat = self.state_encoder(states)
        return self.fusion(torch.cat([img_feat, st_feat], dim=1))

def train_bc_model(model, dataloader, num_epochs=100):
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(num_epochs):
        total = 0.0
        for batch in dataloader:
            # Assumes the dataloader yields stacked float tensors: one camera's
            # frames as (B, T, H, W, C) plus states and actions as (B, T, dim).
            # Only the first timestep of each sample is used here.
            images = list(batch['images'].values())[0][:, 0].permute(0, 3, 1, 2)
            states = batch['states'][:, 0]
            actions = batch['actions'][:, 0]
            optimizer.zero_grad()
            loss = criterion(model(images, states), actions)
            loss.backward()
            optimizer.step()
            total += loss.item()
        print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {total / len(dataloader):.4f}')
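
A minimal sketch of putting the pieces together; the 7-dimensional state and action sizes are placeholders for your robot, and train_loader stands in for a dataloader that already yields stacked (B, T, ...) tensors as assumed in train_bc_model:

bc_model = BehaviorCloningModel(state_dim=7, action_dim=7)
train_bc_model(bc_model, train_loader, num_epochs=50)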

Sequence modeling (Transformer)

class TransformerPolicy(nn.Module):
    def __init__(self, state_dim, action_dim, seq_len=50, d_model=256):
        super().__init__()
        self.seq_len, self.d_model = seq_len, d_model
        self.state_proj = nn.Linear(state_dim, d_model)
        self.action_proj = nn.Linear(action_dim, d_model)
        self.pos_encoding = nn.Parameter(torch.randn(seq_len, d_model))
        enc_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=8, batch_first=True)
        self.transformer = nn.TransformerEncoder(enc_layer, num_layers=6)
        self.output_proj = nn.Linear(d_model, action_dim)

    def forward(self, states, actions=None):
        bsz, seq_len = states.shape[:2]
        st = self.state_proj(states)
        if actions is not None:
            # Teacher forcing: feed the previous action at each step, with a
            # zero vector in place of the (unknown) first action.
            act = self.action_proj(actions)
            act_in = torch.cat([torch.zeros(bsz, 1, self.d_model, device=actions.device), act[:, :-1]], dim=1)
            inputs = st + act_in
        else:
            inputs = st
        inputs = inputs + self.pos_encoding[:seq_len]
        out = self.transformer(inputs)
        return self.output_proj(out)
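
A minimal sketch of one sequence-level training step, assuming fixed-length (B, T, dim) windows cut from episodes; the 7-dimensional states and actions are placeholders:

import torch.nn.functional as F

policy = TransformerPolicy(state_dim=7, action_dim=7, seq_len=50)
policy_opt = optim.Adam(policy.parameters(), lr=1e-4)

def train_seq_step(states, actions):
    # states, actions: (B, T, dim) with T <= seq_len.
    pred = policy(states, actions)      # teacher-forced prediction, (B, T, action_dim)
    loss = F.mse_loss(pred, actions)    # supervise the action at every step
    policy_opt.zero_grad(); loss.backward(); policy_opt.step()
    return loss.item()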

VLA model (vision‑language‑action)

from transformers import AutoTokenizer, AutoModel

class VLAModel(nn.Module):
    def __init__(self, action_dim, language_model='bert-base-uncased'):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(language_model)
        self.language_encoder = AutoModel.from_pretrained(language_model)
        lang_dim = self.language_encoder.config.hidden_size
        self.vision_encoder = nn.Sequential(
            nn.Conv2d(3, 64, 7, stride=2, padding=3), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
            nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d((8, 8)), nn.Flatten(), nn.Linear(256 * 8 * 8, 512)
        )
        # kdim/vdim let the 512-d visual query attend over the language
        # model's hidden states (768-d for bert-base-uncased).
        self.cross_attention = nn.MultiheadAttention(embed_dim=512, num_heads=8,
                                                     kdim=lang_dim, vdim=lang_dim,
                                                     batch_first=True)
        self.action_decoder = nn.Sequential(nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, 128), nn.ReLU(), nn.Linear(128, action_dim))

    def forward(self, images, task_descriptions, states):
        # Tokenize the task strings and move them to the image device.
        tokens = self.tokenizer(task_descriptions, return_tensors='pt', padding=True, truncation=True, max_length=128)
        tokens = {k: v.to(images.device) for k, v in tokens.items()}
        lang = self.language_encoder(**tokens).last_hidden_state    # (B, L, lang_dim)
        vis = self.vision_encoder(images).unsqueeze(1)              # (B, 1, 512)
        attn, _ = self.cross_attention(vis, lang, lang)
        # states is accepted for extension (e.g. fusing proprioception) but is
        # unused in this minimal model.
        return self.action_decoder(attn.squeeze(1))
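
A minimal forward-pass sketch, taking the instruction from the episode's task attribute; the 7-dimensional action size is a placeholder:

vla = VLAModel(action_dim=7)
sample = dataset[0]
first_cam = next(iter(sample['images'].values()))
frame = image_transform(first_cam[0]).unsqueeze(0)          # (1, 3, 224, 224)
pred_action = vla(frame, [sample['task']], sample['states'][:1])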

Offline RL example

import torch.nn.functional as F

class OfflineRLAgent:
    def __init__(self, state_dim, action_dim, lr=3e-4):
        # State-only MLP actor: train_step only receives low-dimensional
        # states, so the image-conditioned BehaviorCloningModel cannot be
        # plugged in here directly.
        self.actor = nn.Sequential(
            nn.Linear(state_dim, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, action_dim)
        )
        # Q(s, a) critic.
        self.critic = nn.Sequential(
            nn.Linear(state_dim + action_dim, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 1)
        )
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=lr)

    def train_step(self, states, actions, rewards, next_states, dones):
        # rewards and dones are expected with shape (B, 1).
        # Critic update: one-step TD target with discount 0.99.
        with torch.no_grad():
            next_actions = self.actor(next_states)
            target_q = rewards + 0.99 * (1 - dones) * self.critic(torch.cat([next_states, next_actions], dim=1))
        current_q = self.critic(torch.cat([states, actions], dim=1))
        critic_loss = F.mse_loss(current_q, target_q)
        self.critic_optimizer.zero_grad(); critic_loss.backward(); self.critic_optimizer.step()
        # Actor update: maximize the critic's estimate of the actor's actions.
        pred = self.actor(states)
        actor_loss = -self.critic(torch.cat([states, pred], dim=1)).mean()
        self.actor_optimizer.zero_grad(); actor_loss.backward(); self.actor_optimizer.step()
        return critic_loss.item(), actor_loss.item()
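
The loader above reads states, actions, and an episode-level score, but no per-step reward, so offline RL needs a reward convention. A sketch under the assumption (not part of the export format) that the score is used as a sparse terminal reward; state and action sizes are placeholders:

def episode_to_transitions(ep):
    # ep is one item from RobotDataset; builds (s, a, r, s', done) tensors.
    states, actions = ep['states'], ep['actions']
    n = states.shape[0] - 1                    # number of transitions
    rewards = torch.zeros(n, 1)
    rewards[-1] = float(ep['score'])           # assumed terminal reward
    dones = torch.zeros(n, 1)
    dones[-1] = 1.0
    return states[:n], actions[:n], rewards, states[1:n + 1], dones

agent = OfflineRLAgent(state_dim=7, action_dim=7)
s, a, r, s_next, d = episode_to_transitions(dataset[0])
critic_loss, actor_loss = agent.train_step(s, a, r, s_next, d)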

Augmentation & evaluation

import torchvision.transforms as T

class RobotDataAugmentation:
    def __init__(self):
        # Mild photometric and geometric jitter; horizontal flips are kept rare
        # because mirroring breaks consistency with the recorded actions.
        self.image_aug = T.Compose([
            T.ColorJitter(0.2, 0.2, 0.2, 0.1), T.RandomRotation(5),
            T.RandomResizedCrop(224, scale=(0.9, 1.0)), T.RandomHorizontalFlip(0.1),
        ])

    def augment_episode(self, ep):
        # Copy the nested images dict so the original episode is not mutated.
        out = ep.copy()
        out['images'] = dict(ep['images'])
        for cam, imgs in ep['images'].items():
            out['images'][cam] = [self.image_aug(img) if torch.rand(1) < 0.5 else img for img in imgs]
        # Occasionally add small Gaussian noise to the actions.
        if torch.rand(1) < 0.3:
            out['actions'] = ep['actions'] + torch.randn_like(ep['actions']) * 0.01
        return out
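
Augmentation is typically applied to the raw PIL frames before preprocess_images, for example:

augmenter = RobotDataAugmentation()
augmented = augmenter.augment_episode(dataset[0])
pixels = preprocess_images(augmented['images'], image_transform)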

Metrics

def evaluate_model(model, test_dataloader, device):
    model.eval()
    total, n = 0.0, 0
    with torch.no_grad():
        for batch in test_dataloader:
            # Same tensor layout as in training: one camera, first timestep.
            images = list(batch['images'].values())[0][:, 0].permute(0, 3, 1, 2).to(device)
            states = batch['states'][:, 0].to(device)
            actions = batch['actions'][:, 0].to(device)
            mse = F.mse_loss(model(images, states), actions)
            total += mse.item() * len(actions)
            n += len(actions)
    return total / n
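
Example usage, where test_loader is a placeholder for a held-out split built the same way as the training dataloader:

device = 'cuda' if torch.cuda.is_available() else 'cpu'
bc_model = bc_model.to(device)
test_mse = evaluate_model(bc_model, test_loader, device)
print(f'Test action MSE: {test_mse:.4f}')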

Summary

  • Imitation: learn from expert demos
  • Sequence: model temporal dependencies (Transformer)
  • Multimodal: combine vision, language, actions (VLA)
  • RL: offline pretraining or replay initialization