Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,7 @@ build*/
books/*

# Spsa output for script usage
spsa_out*.txt
spsa_out*.txt

# Venv folder for python scripts
venv/
176 changes: 92 additions & 84 deletions scripts/extract-data.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,28 @@
import os
import random
import io
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Optional
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import List
import chess
import chess.pgn
import io
from tqdm import tqdm


# ------------------------------------------------------------
# Game logic
# ------------------------------------------------------------

def is_capture_or_promotion(board: chess.Board, move: chess.Move) -> bool:
    """Tell whether `move` captures on `board` or carries a promotion piece."""
    captured = board.is_capture(move)
    if captured:
        return captured
    # Not a capture: the move still qualifies when a promotion piece is set.
    return move.promotion is not None


def extract_positions_from_game(game_data: str, n_positions: int = 5) -> List[str]:
"""Extract N sampled positions from a game, with original restrictions, labeled with final result (w/d/b)."""
try:
pgn_io = io.StringIO(game_data)
game = chess.pgn.read_game(pgn_io)
if not game:
return []

# Map result to label
result = game.headers.get("Result", "")
if result == "1-0":
outcome = "w"
Expand All @@ -32,27 +31,22 @@ def extract_positions_from_game(game_data: str, n_positions: int = 5) -> List[st
elif result == "1/2-1/2":
outcome = "d"
else:
return [] # Skip unfinished/invalid games
return []

board = game.board()
valid_positions = []
move_number = 0

# Iterate through all moves in the game
for node in game.mainline():
move = node.move
if move:
board.push(move)
move_number += 1

# Skip if in check or no legal moves
if board.is_check() or not any(board.legal_moves):
continue

next_node = node.next()
best_move = next_node.move if next_node else None

# Skip if best move is capture or promotion
if best_move and is_capture_or_promotion(board, best_move):
continue

Expand All @@ -61,7 +55,6 @@ def extract_positions_from_game(game_data: str, n_positions: int = 5) -> List[st
if not valid_positions:
return []

# Randomly select up to n_positions
if len(valid_positions) > n_positions:
valid_positions = random.sample(valid_positions, n_positions)

Expand All @@ -71,115 +64,130 @@ def extract_positions_from_game(game_data: str, n_positions: int = 5) -> List[st
return []


def count_games_in_file(file_path: Path) -> int:
    """Count the games in a PGN file by counting its ``[Event `` header lines.

    The file is scanned line by line, so memory use does not grow with the
    file size. Only lines that *start* with the tag are counted, so an
    ``[Event `` fragment inside a move comment is not mistaken for a new game.

    Args:
        file_path: Path of the PGN file to scan.

    Returns:
        The number of lines starting with ``[Event ``, or 0 when the file
        cannot be opened, read, or decoded as UTF-8.
    """
    games = 0
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                if line.startswith("[Event "):
                    games += 1
    except (OSError, UnicodeError):
        # Best-effort count: an unreadable file contributes no games.
        return 0
    return games
def process_single_game(game_data: str, positions_per_game: int) -> List[str]:
    """Sample positions from one game's PGN text.

    Thin module-level wrapper around `extract_positions_from_game`;
    presumably kept separate so it can be submitted to the worker pool.
    """
    sampled = extract_positions_from_game(game_data, n_positions=positions_per_game)
    return sampled


def process_pgn_file(file_path: Path, positions_per_game: int = 5, pbar: Optional[tqdm] = None) -> List[str]:
results = []
try:
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()

games = []
current_game = []
for line in content.split("\n"):
if line.strip().startswith("[Event ") and current_game:
games.append("\n".join(current_game))
current_game = [line]
# ------------------------------------------------------------
# PGN streaming utilities
# ------------------------------------------------------------

def stream_pgn_games(file_path: Path):
with open(file_path, "r", encoding="utf-8") as f:
game_lines = []
for line in f:
if line.startswith("[Event ") and game_lines:
yield "".join(game_lines)
game_lines = [line]
else:
current_game.append(line)
if current_game:
games.append("\n".join(current_game))

for game_data in games:
if game_data.strip():
positions = extract_positions_from_game(game_data, positions_per_game)
results.extend(positions)
if pbar:
pbar.update(1)
game_lines.append(line)

except Exception as e:
print(f"Error processing file {file_path}: {e}")
if pbar:
expected_games = count_games_in_file(file_path)
pbar.update(expected_games)
if game_lines:
yield "".join(game_lines)

return results

def count_games_in_file(file_path: Path) -> int:
    """Return how many lines of the UTF-8 file start with ``[Event ``.

    The file is iterated line by line rather than read in one piece.
    Errors from opening or decoding the file propagate to the caller.
    """
    with open(file_path, "r", encoding="utf-8") as pgn:
        return sum(1 for row in pgn if row.startswith("[Event "))


def extract_positions_from_folder(folder_path: str, positions_per_game: int = 5, max_workers: int = 4) -> List[str]:
# ------------------------------------------------------------
# Main multiprocessing pipeline
# ------------------------------------------------------------

def extract_positions_from_folder(
folder_path: str,
positions_per_game: int = 5,
max_workers: int = 4,
task_buffer: int = 2000,
) -> List[str]:

folder = Path(folder_path)
if not folder.exists():
raise FileNotFoundError(f"Folder {folder_path} does not exist")
raise FileNotFoundError(f"Folder does not exist: {folder}")

pgn_files = list(folder.glob("*.pgn"))
if not pgn_files:
print(f"No PGN files found in {folder_path}")
print("No PGN files found.")
return []

print(f"Found {len(pgn_files)} PGN files")

total_games = 0
for pgn_file in tqdm(pgn_files, desc="Counting games"):
total_games += count_games_in_file(pgn_file)

total_games = sum(count_games_in_file(p) for p in pgn_files)
print(f"Total games to process: {total_games}")

all_results = []
with tqdm(total=total_games, desc="Processing games", unit="games") as pbar:
with ThreadPoolExecutor(max_workers=max_workers) as executor:
future_to_file = {
executor.submit(process_pgn_file, pgn_file, positions_per_game, pbar): pgn_file
for pgn_file in pgn_files
}
for future in as_completed(future_to_file):
try:
all_results.extend(future.result())
except Exception as e:
print(f"Error: {e}")
futures = []

with ProcessPoolExecutor(max_workers=max_workers) as executor:
with tqdm(total=total_games, desc="Processing games", unit="games") as pbar:

for pgn_file in pgn_files:
for game_data in stream_pgn_games(pgn_file):
futures.append(
executor.submit(
process_single_game,
game_data,
positions_per_game,
)
)

if len(futures) >= task_buffer:
for future in as_completed(futures):
all_results.extend(future.result())
pbar.update(1)
futures.clear()

for future in as_completed(futures):
all_results.extend(future.result())
pbar.update(1)

return all_results


def save_positions_to_file(positions: List[str], output_file: str = "sampled_positions.txt"):
# ------------------------------------------------------------
# Output
# ------------------------------------------------------------

def save_positions_to_file(positions: List[str], output_file: str):
    """Write every position on its own line to `output_file` and report the count.

    The file is created or overwritten as UTF-8 text; a one-line summary is
    printed once the positions have been written.
    """
    with open(output_file, "w", encoding="utf-8") as out:
        out.writelines(entry + "\n" for entry in positions)
    saved = len(positions)
    print(f"Saved {saved} positions to {output_file}")


# ------------------------------------------------------------
# CLI
# ------------------------------------------------------------

def main():
folder_path = input("Enter the path to the folder containing PGN files: ").strip()

try:
positions_per_game = int(input("Enter number of positions per game (default 5): ") or "5")
except ValueError:
positions_per_game = 5

try:
max_workers = int(input("Enter maximum number of worker threads (default 4): ") or "4")
max_workers = int(input("Enter maximum number of worker processes (default 4): ") or "4")
except ValueError:
max_workers = 4

try:
positions = extract_positions_from_folder(
folder_path, positions_per_game=positions_per_game, max_workers=max_workers
)
if positions:
output_file = f"sampled_positions_{len(positions)}.txt"
save_positions_to_file(positions, output_file)
print(f"\nSummary:")
print(f"- Total positions saved: {len(positions)}")
print(f"- File: {output_file}")
else:
print("No valid games found!")
except Exception as e:
print(f"Error: {e}")
positions = extract_positions_from_folder(
folder_path,
positions_per_game=positions_per_game,
max_workers=max_workers,
)

if positions:
output_file = f"sampled_positions_{len(positions)}.txt"
save_positions_to_file(positions, output_file)
else:
print("No positions extracted.")


if __name__ == "__main__":
Expand All @@ -188,7 +196,7 @@ def main():
import chess.pgn
from tqdm import tqdm
except ImportError:
print("Please install required dependencies:")
print("Please install dependencies:")
print("pip install python-chess tqdm")
exit(1)

Expand Down
Loading
Loading