17 changes: 14 additions & 3 deletions lerobot/common/datasets/compute_stats.py
@@ -85,16 +85,27 @@ def get_feature_stats(array: np.ndarray, axis: tuple, keepdims: bool) -> dict[st
 def compute_episode_stats(episode_data: dict[str, list[str] | np.ndarray], features: dict) -> dict:
     ep_stats = {}
     for key, data in episode_data.items():
+        if key not in features:
+            continue  # Skip keys that are not in features
         if features[key]["dtype"] == "string":
             continue  # HACK: we should receive np.arrays of strings
-        elif features[key]["dtype"] in ["image", "video"]:
+        if features[key]["dtype"] in ["image", "video"]:
             ep_ft_array = sample_images(data)  # data is a list of image paths
             axes_to_reduce = (0, 2, 3)  # keep channel dim
             keepdims = True
         else:
-            ep_ft_array = data  # data is already a np.ndarray
+            # Convert list to numpy array if needed
+            if isinstance(data, list):
+                ep_ft_array = np.stack(data)
+            else:
+                ep_ft_array = data  # data is already a np.ndarray
+
+            # Convert scalar values to numpy arrays
+            if not isinstance(ep_ft_array, np.ndarray):
+                ep_ft_array = np.array([ep_ft_array])
+
             axes_to_reduce = 0  # compute stats over the first axis
-            keepdims = data.ndim == 1  # keep as np.array
+            keepdims = ep_ft_array.ndim == 1  # keep as np.array
 
         ep_stats[key] = get_feature_stats(ep_ft_array, axis=axes_to_reduce, keepdims=keepdims)
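For context, a minimal sketch of the normalization this hunk introduces; the sample values are hypothetical, not taken from a real dataset:

    import numpy as np

    # A list of per-frame vectors, as an episode buffer may hold before stacking
    list_of_frames = [np.zeros(6), np.ones(6)]
    ep_ft_array = np.stack(list_of_frames)
    print(ep_ft_array.shape)  # (2, 6): stats can now reduce over axis 0

    # A plain Python scalar, e.g. a reward logged as a float
    ep_ft_array = 3.5
    if not isinstance(ep_ft_array, np.ndarray):
        ep_ft_array = np.array([ep_ft_array])
    print(ep_ft_array.shape, ep_ft_array.ndim == 1)  # (1,) True

With both conversions applied, `get_feature_stats` always receives an `np.ndarray` whose first axis it can reduce over.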
16 changes: 8 additions & 8 deletions lerobot/common/datasets/lerobot_dataset.py
@@ -875,10 +875,10 @@ def save_episode(self, episode_data: dict | None = None) -> None:
         self._save_episode_table(episode_buffer, episode_index)
         ep_stats = compute_episode_stats(episode_buffer, self.features)
 
-        if len(self.meta.video_keys) > 0:
-            video_paths = self.encode_episode_videos(episode_index)
-            for key in self.meta.video_keys:
-                episode_buffer[key] = video_paths[key]
+        # if len(self.meta.video_keys) > 0:
+        #     video_paths = self.encode_episode_videos(episode_index)
+        #     for key in self.meta.video_keys:
+        #         episode_buffer[key] = video_paths[key]
 
         # `meta.save_episode` be executed after encoding the videos
         self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)
@@ -899,10 +899,10 @@ def save_episode(self, episode_data: dict | None = None) -> None:
         parquet_files = list(self.root.rglob("*.parquet"))
         assert len(parquet_files) == self.num_episodes
 
-        # delete images
-        img_dir = self.root / "images"
-        if img_dir.is_dir():
-            shutil.rmtree(self.root / "images")
+        # Comment out image deletion to preserve images
+        # img_dir = self.root / "images"
+        # if img_dir.is_dir():
+        #     shutil.rmtree(self.root / "images")
 
         if not episode_data:  # Reset the buffer
             self.episode_buffer = self.create_episode_buffer()
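If preserving the raw frames is the goal, an alternative to commenting the cleanup out is to gate it behind a flag. A minimal sketch, assuming a hypothetical `keep_image_files` parameter that is not part of the LeRobot API:

    import shutil
    from pathlib import Path

    def cleanup_episode_images(root: Path, keep_image_files: bool = True) -> None:
        # Mirrors the commented-out block above: delete the raw image
        # directory unless the caller asked to keep it.
        if keep_image_files:
            return
        img_dir = root / "images"
        if img_dir.is_dir():
            shutil.rmtree(img_dir)

This makes keeping the images an explicit, reversible choice rather than a local edit.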
72 changes: 72 additions & 0 deletions lerobot/common/datasets/load_and_compute_stats.py
@@ -0,0 +1,72 @@
+from pathlib import Path
+
+from lerobot.common.datasets.compute_stats import compute_episode_stats
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.utils import write_episode_stats
+
+def main(dataset_path: str) -> None:
+    dataset_path = Path(dataset_path)
+
+    print(f"Dataset path: {dataset_path}")
+
+    # Use a tolerance that is considerably larger than the default
+    dataset = LeRobotDataset(f"root/{dataset_path.name}", dataset_path, tolerance_s=0.1)
+    dataset.load_hf_dataset()
+
+    # Get the episode data index to iterate over episodes and frames
+    episode_data_index = dataset.episode_data_index
+
+    # Iterate over episodes
+    num_episodes = dataset.num_episodes
+    for episode_idx in range(num_episodes):
+        print(f"Processing episode: {episode_idx} / {num_episodes - 1}...")
+
+        # Create a new episode buffer with the correct episode index
+        dataset.episode_buffer = dataset.create_episode_buffer(episode_idx)
+
+        # Get the start and end frame indices for this episode
+        start_frame_idx = episode_data_index["from"][episode_idx].item()
+        end_frame_idx = episode_data_index["to"][episode_idx].item()
+
+        # Iterate over frames in this episode
+        for frame_idx in range(start_frame_idx, end_frame_idx):
+            print(f"\rFrame: {start_frame_idx} -> {frame_idx} / {end_frame_idx}", end="", flush=True)
+
+            # Get the frame data from the dataset
+            frame_data = dataset[frame_idx]
+            timestamp = frame_data["timestamp"]
+
+            # Create a complete frame dictionary with all the necessary data
+            frame = {}
+
+            # Add all keys from frame_data except those that are automatically handled
+            for key in frame_data:
+                # Skip keys that are automatically handled by add_frame
+                if key not in ["index", "episode_index", "frame_index", "task_index"]:
+                    frame[key] = frame_data[key]
+
+            # Make sure to use the actual timestamp
+            frame["timestamp"] = timestamp
+
+            # Add the task if available
+            if "task" in frame_data:
+                frame["task"] = frame_data["task"]
+
+            # Add the frame to the current episode buffer
+            dataset.add_frame(frame)
+
+        print("\nDone adding frames.")
+
+        # Compute episode stats
+        ep_stats = compute_episode_stats(dataset.episode_buffer, dataset.features)
+
+        # Save episode stats
+        # write_episode_stats(episode_idx, ep_stats, dataset_path)
+
+if __name__ == "__main__":
+    import sys
+
+    if len(sys.argv) != 2:
+        print("Usage: python script.py <dataset_path>")
+        sys.exit(1)
+    main(sys.argv[1])
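Once the recomputed stats look correct, the last step is to re-enable the commented-out write. A minimal sketch of that call with hypothetical values; the stats dict below mirrors the min/max/mean/std/count layout produced by `get_feature_stats`, and the dataset path is made up:

    from pathlib import Path

    import numpy as np

    from lerobot.common.datasets.utils import write_episode_stats

    # Hypothetical stats for one 6-dim state feature over a 250-frame episode
    ep_stats = {
        "observation.state": {
            "min": np.zeros(6),
            "max": np.ones(6),
            "mean": np.full(6, 0.5),
            "std": np.full(6, 0.1),
            "count": np.array([250]),
        }
    }
    # Same positional arguments as the commented-out call in the script
    write_episode_stats(0, ep_stats, Path("/path/to/dataset"))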
10 changes: 7 additions & 3 deletions lerobot/common/datasets/utils.py
@@ -757,8 +757,10 @@ def validate_feature_numpy_array(
         if actual_dtype != np.dtype(expected_dtype):
             error_message += f"The feature '{name}' of dtype '{actual_dtype}' is not of the expected dtype '{expected_dtype}'.\n"
 
+        # Special case for scalar values (shape '()') when expecting shape '(1,)'
         if actual_shape != expected_shape:
-            error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n"
+            if not (len(expected_shape) == 1 and expected_shape[0] == 1 and actual_shape == ()):
+                error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{expected_shape}'.\n"
     else:
         error_message += f"The feature '{name}' is not a 'np.ndarray'. Expected type is '{expected_dtype}', but type '{type(value)}' provided instead.\n"
@@ -770,8 +772,10 @@ def validate_feature_image_or_video(name: str, expected_shape: list[str], value:
     error_message = ""
     if isinstance(value, np.ndarray):
         actual_shape = value.shape
-        c, h, w = expected_shape
-        if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c)):
+        # Convert expected_shape values to integers if they are strings
+        c, h, w = [int(dim) if isinstance(dim, str) else dim for dim in expected_shape]
+        # Check that the shape is 3D and matches (c, h, w), (h, w, c), or (w, c, h)
+        if len(actual_shape) != 3 or (actual_shape != (c, h, w) and actual_shape != (h, w, c) and actual_shape != (w, c, h)):
             error_message += f"The feature '{name}' of shape '{actual_shape}' does not have the expected shape '{(c, h, w)}' or '{(h, w, c)}'.\n"
     elif isinstance(value, PILImage.Image):
         pass
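Two quick checks of the new validation logic, runnable standalone; the shape values are made up for illustration:

    import numpy as np

    # Scalar special case: a 0-d array is now accepted when shape (1,) is expected
    expected_shape = [1]
    actual_shape = np.array(3.5).shape  # ()
    is_scalar_ok = len(expected_shape) == 1 and expected_shape[0] == 1 and actual_shape == ()
    print(is_scalar_ok)  # True -> no shape error appended

    # String dims: shapes deserialized from JSON metadata may arrive as strings
    expected = ["3", "480", "640"]
    c, h, w = [int(dim) if isinstance(dim, str) else dim for dim in expected]
    print((c, h, w))  # (3, 480, 640)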