Training with HDF5

This page explains the basic workflow for robot learning (imitation learning, reinforcement learning, and VLA) using the exported HDF5 datasets.

Preprocessing

Preprocess the HDF5 data to match your model and framework.
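
Before writing a loader, it helps to confirm how an exported file is laid out. The short sketch below only assumes the structure used on this page (episode groups under /data, datasets such as action and observation.state, and episode attributes like task and score); adjust the paths if your export differs.

import h5py

def inspect_hdf5(path):
    # Print every episode group with its attributes, dataset shapes, and dtypes.
    with h5py.File(path, 'r') as f:
        for ep in f['/data'].keys():
            g = f[f'/data/{ep}']
            print(ep, dict(g.attrs))
            for name, ds in g.items():
                if isinstance(ds, h5py.Dataset):
                    print(' ', name, ds.shape, ds.dtype)

# inspect_hdf5('episode_000.hdf5')  # hypothetical file name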

Basic loader

import h5py, io
import torch
from torch.utils.data import Dataset
from PIL import Image

class RobotDataset(Dataset):
    """Index every episode across a list of exported HDF5 files."""
    def __init__(self, hdf5_files, transform=None):
        self.hdf5_files = hdf5_files
        self.transform = transform
        self.episodes = []
        # Build a flat (file, episode) index so __getitem__ can open files lazily.
        for fp in hdf5_files:
            with h5py.File(fp, 'r') as f:
                for ep in f['/data'].keys():
                    self.episodes.append((fp, ep))

    def __len__(self):
        return len(self.episodes)

    def __getitem__(self, idx):
        fp, ep = self.episodes[idx]
        with h5py.File(fp, 'r') as f:
            g = f[f'/data/{ep}']
            actions = g['action'][:]
            states = g['observation.state'][:]
            gripper = g['observation.gripper'][:]
            images = {}
            # Each observation.images.<camera> dataset stores one encoded frame per step.
            for k in g.keys():
                if k.startswith('observation.images.'):
                    cam = k.split('.')[-1]
                    frames = [Image.open(io.BytesIO(fr)) for fr in g[k][:]]
                    if self.transform is not None:
                        frames = [self.transform(img) for img in frames]
                    images[cam] = frames
            return {
                'actions': torch.FloatTensor(actions),
                'states': torch.FloatTensor(states),
                'gripper': torch.FloatTensor(gripper),
                'images': images,
                'task': g.attrs.get('task', ''),
                'task_zh': g.attrs.get('task_zh', ''),
                'score': g.attrs.get('score', 0.0),
            }

Image preprocessing

import torchvision.transforms as T

image_transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
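
With image_transform passed to RobotDataset, each camera frame comes back as a normalized tensor, and a small collate_fn can batch one timestep per episode for single-step behavior cloning. This is a minimal sketch under those assumptions (first timestep only, one arbitrary camera, hypothetical file names); sequence models instead need padding or fixed-length windows.

import torch
from torch.utils.data import DataLoader

def collate_first_step(batch):
    # Batch only the first timestep of each episode and one camera per sample.
    states = torch.stack([item['states'][0] for item in batch])
    actions = torch.stack([item['actions'][0] for item in batch])
    images = torch.stack([next(iter(item['images'].values()))[0] for item in batch])
    return {
        'images': images,
        'states': states,
        'actions': actions,
        'task': [item['task'] for item in batch],
    }

train_files = ['episode_000.hdf5']  # hypothetical list of exported HDF5 paths
loader = DataLoader(RobotDataset(train_files, transform=image_transform),
                    batch_size=32, shuffle=True, collate_fn=collate_first_step)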

Imitation learning (behavior cloning)

import torch.nn as nn, torch.optim as optim

class BehaviorCloningModel(nn.Module):
    """Predict a single action from one camera image and the robot state."""
    def __init__(self, state_dim, action_dim, image_channels=3):
        super().__init__()
        # Small CNN encoder; AdaptiveAvgPool2d makes it input-size agnostic.
        self.image_encoder = nn.Sequential(
            nn.Conv2d(image_channels, 32, 3, 2, 1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, 2, 1), nn.ReLU(),
            nn.Conv2d(64, 128, 3, 2, 1), nn.ReLU(),
            nn.AdaptiveAvgPool2d((4, 4)), nn.Flatten(), nn.Linear(128 * 4 * 4, 256),
        )
        self.state_encoder = nn.Sequential(
            nn.Linear(state_dim, 128), nn.ReLU(), nn.Linear(128, 128),
        )
        # 256 image features + 128 state features = 384 fused inputs.
        self.fusion = nn.Sequential(
            nn.Linear(384, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, action_dim),
        )

    def forward(self, images, states):
        features = torch.cat([self.image_encoder(images), self.state_encoder(states)], dim=1)
        return self.fusion(features)
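
A minimal training loop for the model above, assuming batches shaped like the collate_first_step sketch earlier; the state/action dimensions and hyperparameters are placeholders to adapt to your robot.

import torch
import torch.nn.functional as F

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = BehaviorCloningModel(state_dim=7, action_dim=7).to(device)  # dims are illustrative
optimizer = optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(50):
    model.train()
    for batch in loader:                     # DataLoader from the previous section
        images = batch['images'].to(device)
        states = batch['states'].to(device)
        actions = batch['actions'].to(device)
        loss = F.mse_loss(model(images, states), actions)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f'epoch {epoch}: loss {loss.item():.4f}')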

Sequence modeling (Transformer)

class TransformerPolicy(nn.Module):
    """Encoder-only policy over a window of states, with past actions mixed in."""
    def __init__(self, state_dim, action_dim, seq_len=50, d_model=256):
        super().__init__()
        self.d_model = d_model
        self.state_proj = nn.Linear(state_dim, d_model)
        self.action_proj = nn.Linear(action_dim, d_model)
        self.pos = nn.Parameter(torch.randn(seq_len, d_model))
        enc = nn.TransformerEncoderLayer(d_model=d_model, nhead=8, batch_first=True)
        self.tr = nn.TransformerEncoder(enc, num_layers=6)
        self.out = nn.Linear(d_model, action_dim)

    def forward(self, states, actions=None):
        x = self.state_proj(states)
        if actions is not None:
            # Teacher forcing: shift actions right by one step so the model only
            # sees past actions when predicting the current one.
            a = self.action_proj(actions)
            a = torch.cat([torch.zeros(a.size(0), 1, self.d_model, device=a.device), a[:, :-1]], dim=1)
            x = x + a
        x = x + self.pos[:x.size(1)]
        return self.out(self.tr(x))
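
The intended call pattern, shown with random tensors so the shapes are explicit; in practice the [B, T, ...] windows would be sliced from the episodes loaded above, and the dimensions here are illustrative.

import torch

policy = TransformerPolicy(state_dim=7, action_dim=7, seq_len=50)
states = torch.randn(4, 50, 7)    # [batch, time, state_dim]
actions = torch.randn(4, 50, 7)   # ground-truth actions, used teacher-forced
pred = policy(states, actions)    # [4, 50, 7]: one action prediction per step
loss = torch.nn.functional.mse_loss(pred, actions)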

VLA (Vision-Language-Action)

from transformers import AutoTokenizer, AutoModel

class VLAModel(nn.Module):
    """Fuse a language instruction with visual features via cross-attention."""
    def __init__(self, action_dim, language_model='bert-base-uncased'):
        super().__init__()
        self.tok = AutoTokenizer.from_pretrained(language_model)
        self.lang = AutoModel.from_pretrained(language_model)
        self.vision = nn.Sequential(
            nn.Conv2d(3, 64, 7, 2, 3), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, 1, 1), nn.ReLU(),
            nn.Conv2d(128, 256, 3, 1, 1), nn.ReLU(),
            nn.AdaptiveAvgPool2d((8, 8)), nn.Flatten(), nn.Linear(256 * 8 * 8, 512),
        )
        # Vision features (512-d) attend over language tokens (hidden_size-d, 768 for BERT base).
        h = self.lang.config.hidden_size
        self.cross = nn.MultiheadAttention(512, 8, kdim=h, vdim=h, batch_first=True)
        self.dec = nn.Sequential(
            nn.Linear(512, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, action_dim),
        )

    def forward(self, images, texts, states):
        # Note: states is accepted for interface compatibility but not used here.
        tk = self.tok(texts, return_tensors='pt', padding=True, truncation=True, max_length=128)
        tk = {k: v.to(images.device) for k, v in tk.items()}
        L = self.lang(**tk).last_hidden_state   # [B, tokens, hidden_size]
        V = self.vision(images).unsqueeze(1)    # [B, 1, 512]
        A, _ = self.cross(V, L, L)              # query = vision, key/value = language
        return self.dec(A.squeeze(1))
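
A shape check with dummy inputs; the example strings stand in for the per-episode task attribute and are purely illustrative. Downloading bert-base-uncased requires network access, and the state vector is currently unused by this sketch.

import torch

vla = VLAModel(action_dim=7)
images = torch.randn(2, 3, 224, 224)                  # normalized RGB frames
texts = ['pick up the red block', 'open the drawer']  # e.g. the task attribute of each episode
states = torch.randn(2, 7)                            # accepted but unused in the forward pass
actions = vla(images, texts, states)                  # -> [2, 7]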

Offline RL (example)

import torch.nn.functional as F

class OfflineRLAgent:
    """Minimal actor-critic trained purely from logged transitions."""
    def __init__(self, state_dim, action_dim, lr=3e-4):
        # State-only actor; the image-conditioned BehaviorCloningModel can be
        # substituted if image observations are also passed to train_step.
        self.actor = nn.Sequential(
            nn.Linear(state_dim, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, action_dim),
        )
        self.critic = nn.Sequential(
            nn.Linear(state_dim + action_dim, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 1),
        )
        self.actor_opt = optim.Adam(self.actor.parameters(), lr=lr)
        self.critic_opt = optim.Adam(self.critic.parameters(), lr=lr)

    def train_step(self, s, a, r, ns, d):
        # Critic update: one-step TD target with gamma = 0.99.
        with torch.no_grad():
            na = self.actor(ns)
            tq = r + 0.99 * (1 - d) * self.critic(torch.cat([ns, na], 1))
        cq = self.critic(torch.cat([s, a], 1))
        cl = F.mse_loss(cq, tq)
        self.critic_opt.zero_grad(); cl.backward(); self.critic_opt.step()
        # Actor update: maximize the critic's value of the actor's own action.
        pa = self.actor(s)
        al = -self.critic(torch.cat([s, pa], 1)).mean()
        self.actor_opt.zero_grad(); al.backward(); self.actor_opt.step()
        return cl.item(), al.item()
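
The exported episodes contain states and actions but no per-step reward, so transitions for train_step have to be constructed. The sketch below flattens one episode and uses the episode-level score attribute as a sparse terminal reward; that reward shaping is an assumption for illustration, not something the format prescribes.

import numpy as np
import torch

def episode_to_transitions(states, actions, score):
    # Flatten one episode into (s, a, r, s', done) tuples with a sparse terminal reward.
    transitions = []
    T = len(actions)
    for t in range(T - 1):
        last = (t == T - 2)
        transitions.append((states[t], actions[t],
                            float(score) if last else 0.0,
                            states[t + 1],
                            1.0 if last else 0.0))
    return transitions

# agent = OfflineRLAgent(state_dim=7, action_dim=7)
# s, a, r, ns, d = (torch.as_tensor(np.array(x), dtype=torch.float32) for x in zip(*transitions))
# critic_loss, actor_loss = agent.train_step(s, a, r.unsqueeze(1), ns, d.unsqueeze(1))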

Extras: data augmentation and evaluation

import torchvision.transforms as T

class RobotDataAugmentation:
    def __init__(self):
        # Mild augmentation; horizontal flips are kept rare because mirroring changes handedness.
        self.image_aug = T.Compose([
            T.ColorJitter(0.2, 0.2, 0.2, 0.1),
            T.RandomRotation(5),
            T.RandomResizedCrop(224, scale=(0.9, 1.0)),
            T.RandomHorizontalFlip(0.1),
        ])
    def __call__(self, image):
        return self.image_aug(image)

Evaluation metrics

def evaluate_model(model, dl, device):
    """Mean action MSE on the first timestep of each batched episode.
    Assumes a collate_fn that stacks raw frames per camera into [B, T, H, W, C] tensors."""
    model.eval()
    total, n = 0.0, 0
    with torch.no_grad():
        for b in dl:
            # Use the first camera and the first frame of each episode.
            imgs = list(b['images'].values())[0][:, 0].permute(0, 3, 1, 2).float().to(device)
            st = b['states'][:, 0].to(device)
            gt = b['actions'][:, 0].to(device)
            mse = F.mse_loss(model(imgs, st), gt)
            total += mse.item() * len(gt)
            n += len(gt)
    return total / n
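
evaluate_model indexes b['images'] as a dict of per-camera tensors shaped [B, T, H, W, C], which is a different collation than collate_first_step above. One way to produce that layout from a dataset built without a transform; the fixed frame count and raw uint8 frames are assumptions for illustration, and val_files is a hypothetical list of paths.

import numpy as np
import torch
from torch.utils.data import DataLoader

def collate_raw_frames(batch, n_frames=1):
    # Stack raw decoded frames per camera into [B, n_frames, H, W, C] tensors.
    images = {
        cam: torch.from_numpy(np.stack([
            np.stack([np.asarray(img) for img in item['images'][cam][:n_frames]])
            for item in batch
        ]))
        for cam in batch[0]['images'].keys()
    }
    states = torch.stack([item['states'][:n_frames] for item in batch])
    actions = torch.stack([item['actions'][:n_frames] for item in batch])
    return {'images': images, 'states': states, 'actions': actions}

# val_loader = DataLoader(RobotDataset(val_files), batch_size=16, collate_fn=collate_raw_frames)
# print('action MSE:', evaluate_model(model, val_loader, device))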