Training with HDF5
This page covers data preparation, imitation learning (behavior cloning), sequence modeling with a Transformer, vision-language-action (VLA) models, and offline RL on exported HDF5 datasets.
Preprocessing
Adapt the HDF5 layout to your architecture and framework; the examples below assume the /data/<episode> structure used by the basic loader.
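Before writing a loader, it helps to inspect what an exported file actually contains. A minimal sketch, assuming the /data/<episode>/... layout used by the loader below (the file name is a placeholder):

import h5py

with h5py.File('episodes_000.hdf5', 'r') as f:  # placeholder file name
    def show(name, obj):
        # Print every dataset path together with its shape and dtype.
        if isinstance(obj, h5py.Dataset):
            print(name, obj.shape, obj.dtype)
    f['/data'].visititems(show)
    # Episode-level metadata (task, task_zh, score) is stored in group attributes.
    first = next(iter(f['/data']))
    print(dict(f[f'/data/{first}'].attrs))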
Basic loader
import h5py, io
import torch
from torch.utils.data import Dataset
from PIL import Image

class RobotDataset(Dataset):
    """Flat index over episodes stored across several exported HDF5 files."""
    def __init__(self, files, transform=None):
        self.files, self.transform = files, transform
        self.episodes = []
        # Build a (file, episode) index up front; files are reopened lazily in __getitem__.
        for fp in files:
            with h5py.File(fp, 'r') as f:
                for ep in f['/data'].keys():
                    self.episodes.append((fp, ep))

    def __len__(self):
        return len(self.episodes)

    def __getitem__(self, i):
        fp, ep = self.episodes[i]
        with h5py.File(fp, 'r') as f:
            g = f[f'/data/{ep}']
            actions = g['action'][:]
            states = g['observation.state'][:]
            gripper = g['observation.gripper'][:]
            # Each camera stream is a dataset of encoded frames; decode to RGB PIL images.
            images = {}
            for k in g.keys():
                if k.startswith('observation.images.'):
                    cam = k.split('.')[-1]
                    frames = [Image.open(io.BytesIO(fr)).convert('RGB') for fr in g[k][:]]
                    if self.transform is not None:
                        frames = [self.transform(fr) for fr in frames]
                    images[cam] = frames
            return {
                'actions': torch.FloatTensor(actions),
                'states': torch.FloatTensor(states),
                'gripper': torch.FloatTensor(gripper),
                'images': images,
                'task': g.attrs.get('task', ''),
                'task_zh': g.attrs.get('task_zh', ''),
                'score': g.attrs.get('score', 0.0),
            }
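Episodes have different lengths and carry dicts of PIL images, so the default collate will not stack them. A minimal sketch of a collate that samples one timestep per episode to form flat BC batches; the file paths and the 'front' camera name are assumptions, not part of the export:

import random
import torch
import torchvision.transforms.functional as TF
from torch.utils.data import DataLoader

files = ['episodes_000.hdf5', 'episodes_001.hdf5']  # placeholder paths
ds = RobotDataset(files)

def collate_bc(batch, camera='front'):
    # Sample one random timestep per episode and stack into (image, state, action) tensors.
    imgs, states, actions = [], [], []
    for ep in batch:
        t = random.randrange(len(ep['actions']))
        imgs.append(TF.to_tensor(ep['images'][camera][t]))
        states.append(ep['states'][t])
        actions.append(ep['actions'][t])
    return torch.stack(imgs), torch.stack(states), torch.stack(actions)

dl = DataLoader(ds, batch_size=16, shuffle=True, collate_fn=collate_bc)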
Imitation learning (BC)
import torch.nn as nn, torch.optim as optim

class BehaviorCloningModel(nn.Module):
    def __init__(self, state_dim, action_dim, image_channels=3):
        super().__init__()
        # CNN image encoder -> 256-d feature
        self.image = nn.Sequential(
            nn.Conv2d(image_channels, 32, 3, 2, 1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, 2, 1), nn.ReLU(),
            nn.Conv2d(64, 128, 3, 2, 1), nn.ReLU(),
            nn.AdaptiveAvgPool2d((4, 4)), nn.Flatten(), nn.Linear(128 * 4 * 4, 256),
        )
        # MLP state encoder -> 128-d feature
        self.state = nn.Sequential(nn.Linear(state_dim, 128), nn.ReLU(), nn.Linear(128, 128))
        # Fused 384-d feature -> action regression head
        self.fuse = nn.Sequential(
            nn.Linear(384, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, action_dim),
        )

    def forward(self, images, states):
        return self.fuse(torch.cat([self.image(images), self.state(states)], 1))
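A minimal training loop for the BC model, assuming batches of (image, state, action) tensors like the collate_bc sketch above; the device and hyperparameters are illustrative:

def train_bc(model, dl, device='cuda', epochs=10, lr=1e-4):
    model.to(device).train()
    opt = optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        running = 0.0
        for imgs, states, actions in dl:
            imgs, states, actions = imgs.to(device), states.to(device), actions.to(device)
            loss = nn.functional.mse_loss(model(imgs, states), actions)
            opt.zero_grad(); loss.backward(); opt.step()
            running += loss.item()
        print(f'epoch {epoch}: loss {running / len(dl):.4f}')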
Sequence modeling (Transformer)
class TransformerPolicy(nn.Module):
    def __init__(self, state_dim, action_dim, seq_len=50, d_model=256):
        super().__init__()
        self.d = d_model
        self.s = nn.Linear(state_dim, d_model)    # state embedding
        self.a = nn.Linear(action_dim, d_model)   # action embedding
        self.pos = nn.Parameter(torch.randn(seq_len, d_model))  # learned positional encoding
        enc = nn.TransformerEncoderLayer(d_model=d_model, nhead=8, batch_first=True)
        self.tr = nn.TransformerEncoder(enc, num_layers=6)
        self.out = nn.Linear(d_model, action_dim)

    def forward(self, states, actions=None):
        # states: (B, T, state_dim); actions: (B, T, action_dim) for teacher forcing.
        x = self.s(states)
        if actions is not None:
            # Shift actions right by one step so step t only sees actions up to t-1.
            act = self.a(actions)
            act = torch.cat([torch.zeros(act.size(0), 1, self.d, device=act.device), act[:, :-1]], 1)
            x = x + act
        x = x + self.pos[:x.size(1)]
        return self.out(self.tr(x))
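The policy consumes fixed-length windows of at most seq_len steps. A sketch of cutting an episode from the loader into windows and computing a teacher-forced loss; the window length and the non-overlapping split are assumptions:

def episode_windows(ep, win=50):
    # Cut an episode into non-overlapping windows of length win (the tail is dropped).
    s, a = ep['states'], ep['actions']
    n = (len(a) // win) * win
    return (s[:n].reshape(-1, win, s.shape[-1]),
            a[:n].reshape(-1, win, a.shape[-1]))

def transformer_loss(policy, states, actions):
    # states: (B, T, state_dim), actions: (B, T, action_dim), T <= seq_len.
    # The policy predicts action_t from states[:t+1] and shifted past actions.
    pred = policy(states, actions)
    return nn.functional.mse_loss(pred, actions)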
VLA (vision-language-action)
from transformers import AutoTokenizer, AutoModel

class VLAModel(nn.Module):
    def __init__(self, action_dim, language='bert-base-uncased'):
        super().__init__()
        self.tok = AutoTokenizer.from_pretrained(language)
        self.lang = AutoModel.from_pretrained(language)
        lang_dim = self.lang.config.hidden_size  # 768 for bert-base-uncased
        # CNN visual encoder -> a single 512-d visual token
        self.vis = nn.Sequential(
            nn.Conv2d(3, 64, 7, 2, 3), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, 1, 1), nn.ReLU(),
            nn.Conv2d(128, 256, 3, 1, 1), nn.ReLU(),
            nn.AdaptiveAvgPool2d((8, 8)), nn.Flatten(), nn.Linear(256 * 8 * 8, 512),
        )
        # Cross-attention: the visual token queries the language tokens
        # (kdim/vdim must match the language model's hidden size, which differs from 512).
        self.cross = nn.MultiheadAttention(512, 8, batch_first=True, kdim=lang_dim, vdim=lang_dim)
        self.dec = nn.Sequential(
            nn.Linear(512, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, action_dim),
        )

    def forward(self, images, texts, states=None):
        # Tokenize on the fly and move the tokens to the same device as the images.
        tk = self.tok(texts, return_tensors='pt', padding=True, truncation=True, max_length=128)
        tk = {k: v.to(images.device) for k, v in tk.items()}
        L = self.lang(**tk).last_hidden_state  # (B, T, lang_dim)
        V = self.vis(images).unsqueeze(1)      # (B, 1, 512)
        A, _ = self.cross(V, L, L)             # visual query attends over language tokens
        return self.dec(A.squeeze(1))          # states are unused in this minimal variant
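A forward-pass sketch, using the per-episode task attribute as the instruction; the action dimension, batch size, and image resolution are illustrative:

model = VLAModel(action_dim=7).to('cuda')         # 7-DoF action is an assumption
imgs = torch.rand(4, 3, 224, 224, device='cuda')  # a batch of camera frames
texts = ['pick up the red cube'] * 4              # e.g. the 'task' attribute of an episode
actions = model(imgs, texts)                      # -> (4, 7)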
Offline RL (example)
import torch.nn.functional as F

class OfflineRLAgent:
    def __init__(self, sd, ad, lr=3e-4):
        # State-only actor (the BC model above expects images; here we train from low-dim state).
        self.actor = nn.Sequential(
            nn.Linear(sd, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, ad),
        )
        self.critic = nn.Sequential(
            nn.Linear(sd + ad, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 1),
        )
        self.aopt = optim.Adam(self.actor.parameters(), lr=lr)
        self.copt = optim.Adam(self.critic.parameters(), lr=lr)

    def train_step(self, s, a, r, ns, d):
        # r and d are expected as (B, 1) tensors.
        # Critic update: one-step TD target (no target network in this simplified example).
        with torch.no_grad():
            na = self.actor(ns)
            tq = r + 0.99 * (1 - d) * self.critic(torch.cat([ns, na], 1))
        cq = self.critic(torch.cat([s, a], 1))
        cl = F.mse_loss(cq, tq)
        self.copt.zero_grad(); cl.backward(); self.copt.step()
        # Actor update: maximize the critic's value of the actor's actions.
        pa = self.actor(s)
        al = -self.critic(torch.cat([s, pa], 1)).mean()
        self.aopt.zero_grad(); al.backward(); self.aopt.step()
        return cl.item(), al.item()
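Offline RL consumes flat (s, a, r, s', done) transitions rather than episodes. A sketch of building them from the loader; using the episode-level score attribute as a sparse terminal reward is an assumption about the data, not part of the export format:

def episodes_to_transitions(ds):
    S, A, R, NS, D = [], [], [], [], []
    for i in range(len(ds)):
        ep = ds[i]  # note: this also decodes images; read only states/actions for large datasets
        s, a = ep['states'], ep['actions']
        n_steps = len(a)
        for t in range(n_steps - 1):
            S.append(s[t]); A.append(a[t]); NS.append(s[t + 1])
            last = (t == n_steps - 2)
            R.append(torch.tensor([float(ep['score']) if last else 0.0]))  # sparse reward (assumption)
            D.append(torch.tensor([1.0 if last else 0.0]))
    return tuple(torch.stack(x) for x in (S, A, R, NS, D))

s, a, r, ns, d = episodes_to_transitions(ds)
agent = OfflineRLAgent(sd=s.shape[-1], ad=a.shape[-1])
idx = torch.randint(0, len(s), (256,))  # one illustrative minibatch
critic_loss, actor_loss = agent.train_step(s[idx], a[idx], r[idx], ns[idx], d[idx])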
Augmentation and evaluation
import torchvision.transforms as T

class RobotDataAugmentation:
    def __init__(self):
        # Mild photometric and geometric jitter; horizontal flips are kept rare (p=0.1)
        # because mirrored frames no longer match the recorded actions.
        self.image_aug = T.Compose([
            T.ColorJitter(0.2, 0.2, 0.2, 0.1), T.RandomRotation(5),
            T.RandomResizedCrop(224, scale=(0.9, 1.0)), T.RandomHorizontalFlip(0.1),
        ])

    def __call__(self, image):
        return self.image_aug(image)
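The pipeline plugs straight into the loader's transform hook so every decoded frame is augmented (reusing the files list from the DataLoader sketch above):

aug = RobotDataAugmentation()
ds_aug = RobotDataset(files, transform=aug.image_aug)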
Metrics
def evaluate_model(model, dl, device):
    """Mean action MSE of a BC-style model over a dataloader."""
    model.eval()
    total, n = 0.0, 0
    with torch.no_grad():
        for b in dl:
            # Assumes a collate that stacks each camera into a (B, T, H, W, C) float image tensor;
            # only the first camera and the first timestep are evaluated here.
            imgs = list(b['images'].values())[0][:, 0].permute(0, 3, 1, 2).to(device)
            st = b['states'][:, 0].to(device)
            gt = b['actions'][:, 0].to(device)
            mse = F.mse_loss(model(imgs, st), gt)
            total += mse.item() * len(gt)
            n += len(gt)
    return total / n
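A usage sketch, assuming a validation dataloader whose collate stacks each camera into (B, T, H, W, C) tensors as the function above expects; the model dimensions and val_dl are placeholders:

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = BehaviorCloningModel(state_dim=14, action_dim=7).to(device)  # illustrative dims
val_mse = evaluate_model(model, val_dl, device)  # val_dl: hypothetical validation DataLoader
print(f'validation action MSE: {val_mse:.4f}')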