diff --git a/.gitignore b/.gitignore index 6f8887db..58377e54 100644 --- a/.gitignore +++ b/.gitignore @@ -46,4 +46,7 @@ build*/ books/* # Spsa output for script usage -spsa_out*.txt \ No newline at end of file +spsa_out*.txt + +# Venv folder for python scripts +venv/ \ No newline at end of file diff --git a/scripts/extract-data.py b/scripts/extract-data.py index a032ff92..31b90a48 100644 --- a/scripts/extract-data.py +++ b/scripts/extract-data.py @@ -1,29 +1,28 @@ -import os import random +import io from pathlib import Path -from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import List, Optional +from concurrent.futures import ProcessPoolExecutor, as_completed +from typing import List import chess import chess.pgn -import io from tqdm import tqdm +# ------------------------------------------------------------ +# Game logic +# ------------------------------------------------------------ def is_capture_or_promotion(board: chess.Board, move: chess.Move) -> bool: - """Check if a move is a capture or promotion.""" return board.is_capture(move) or move.promotion is not None def extract_positions_from_game(game_data: str, n_positions: int = 5) -> List[str]: - """Extract N sampled positions from a game, with original restrictions, labeled with final result (w/d/b).""" try: pgn_io = io.StringIO(game_data) game = chess.pgn.read_game(pgn_io) if not game: return [] - # Map result to label result = game.headers.get("Result", "") if result == "1-0": outcome = "w" @@ -32,27 +31,22 @@ def extract_positions_from_game(game_data: str, n_positions: int = 5) -> List[st elif result == "1/2-1/2": outcome = "d" else: - return [] # Skip unfinished/invalid games + return [] board = game.board() valid_positions = [] - move_number = 0 - # Iterate through all moves in the game for node in game.mainline(): move = node.move if move: board.push(move) - move_number += 1 - # Skip if in check or no legal moves if board.is_check() or not any(board.legal_moves): continue next_node = node.next() best_move = next_node.move if next_node else None - # Skip if best move is capture or promotion if best_move and is_capture_or_promotion(board, best_move): continue @@ -61,7 +55,6 @@ def extract_positions_from_game(game_data: str, n_positions: int = 5) -> List[st if not valid_positions: return [] - # Randomly select up to n_positions if len(valid_positions) > n_positions: valid_positions = random.sample(valid_positions, n_positions) @@ -71,115 +64,130 @@ def extract_positions_from_game(game_data: str, n_positions: int = 5) -> List[st return [] -def count_games_in_file(file_path: Path) -> int: - try: - with open(file_path, "r", encoding="utf-8") as f: - content = f.read() - return content.count("[Event ") - except Exception: - return 0 +def process_single_game(game_data: str, positions_per_game: int) -> List[str]: + return extract_positions_from_game(game_data, positions_per_game) -def process_pgn_file(file_path: Path, positions_per_game: int = 5, pbar: Optional[tqdm] = None) -> List[str]: - results = [] - try: - with open(file_path, "r", encoding="utf-8") as f: - content = f.read() - - games = [] - current_game = [] - for line in content.split("\n"): - if line.strip().startswith("[Event ") and current_game: - games.append("\n".join(current_game)) - current_game = [line] +# ------------------------------------------------------------ +# PGN streaming utilities +# ------------------------------------------------------------ + +def stream_pgn_games(file_path: Path): + with open(file_path, "r", encoding="utf-8") as f: + game_lines = [] + for line in f: + if line.startswith("[Event ") and game_lines: + yield "".join(game_lines) + game_lines = [line] else: - current_game.append(line) - if current_game: - games.append("\n".join(current_game)) - - for game_data in games: - if game_data.strip(): - positions = extract_positions_from_game(game_data, positions_per_game) - results.extend(positions) - if pbar: - pbar.update(1) + game_lines.append(line) - except Exception as e: - print(f"Error processing file {file_path}: {e}") - if pbar: - expected_games = count_games_in_file(file_path) - pbar.update(expected_games) + if game_lines: + yield "".join(game_lines) - return results + +def count_games_in_file(file_path: Path) -> int: + count = 0 + with open(file_path, "r", encoding="utf-8") as f: + for line in f: + if line.startswith("[Event "): + count += 1 + return count -def extract_positions_from_folder(folder_path: str, positions_per_game: int = 5, max_workers: int = 4) -> List[str]: +# ------------------------------------------------------------ +# Main multiprocessing pipeline +# ------------------------------------------------------------ + +def extract_positions_from_folder( + folder_path: str, + positions_per_game: int = 5, + max_workers: int = 4, + task_buffer: int = 2000, +) -> List[str]: + folder = Path(folder_path) if not folder.exists(): - raise FileNotFoundError(f"Folder {folder_path} does not exist") + raise FileNotFoundError(f"Folder does not exist: {folder}") pgn_files = list(folder.glob("*.pgn")) if not pgn_files: - print(f"No PGN files found in {folder_path}") + print("No PGN files found.") return [] print(f"Found {len(pgn_files)} PGN files") - total_games = 0 - for pgn_file in tqdm(pgn_files, desc="Counting games"): - total_games += count_games_in_file(pgn_file) - + total_games = sum(count_games_in_file(p) for p in pgn_files) print(f"Total games to process: {total_games}") all_results = [] - with tqdm(total=total_games, desc="Processing games", unit="games") as pbar: - with ThreadPoolExecutor(max_workers=max_workers) as executor: - future_to_file = { - executor.submit(process_pgn_file, pgn_file, positions_per_game, pbar): pgn_file - for pgn_file in pgn_files - } - for future in as_completed(future_to_file): - try: - all_results.extend(future.result()) - except Exception as e: - print(f"Error: {e}") + futures = [] + + with ProcessPoolExecutor(max_workers=max_workers) as executor: + with tqdm(total=total_games, desc="Processing games", unit="games") as pbar: + + for pgn_file in pgn_files: + for game_data in stream_pgn_games(pgn_file): + futures.append( + executor.submit( + process_single_game, + game_data, + positions_per_game, + ) + ) + + if len(futures) >= task_buffer: + for future in as_completed(futures): + all_results.extend(future.result()) + pbar.update(1) + futures.clear() + + for future in as_completed(futures): + all_results.extend(future.result()) + pbar.update(1) return all_results -def save_positions_to_file(positions: List[str], output_file: str = "sampled_positions.txt"): +# ------------------------------------------------------------ +# Output +# ------------------------------------------------------------ + +def save_positions_to_file(positions: List[str], output_file: str): with open(output_file, "w", encoding="utf-8") as f: for line in positions: f.write(line + "\n") print(f"Saved {len(positions)} positions to {output_file}") +# ------------------------------------------------------------ +# CLI +# ------------------------------------------------------------ + def main(): folder_path = input("Enter the path to the folder containing PGN files: ").strip() + try: positions_per_game = int(input("Enter number of positions per game (default 5): ") or "5") except ValueError: positions_per_game = 5 try: - max_workers = int(input("Enter maximum number of worker threads (default 4): ") or "4") + max_workers = int(input("Enter maximum number of worker processes (default 4): ") or "4") except ValueError: max_workers = 4 - try: - positions = extract_positions_from_folder( - folder_path, positions_per_game=positions_per_game, max_workers=max_workers - ) - if positions: - output_file = f"sampled_positions_{len(positions)}.txt" - save_positions_to_file(positions, output_file) - print(f"\nSummary:") - print(f"- Total positions saved: {len(positions)}") - print(f"- File: {output_file}") - else: - print("No valid games found!") - except Exception as e: - print(f"Error: {e}") + positions = extract_positions_from_folder( + folder_path, + positions_per_game=positions_per_game, + max_workers=max_workers, + ) + + if positions: + output_file = f"sampled_positions_{len(positions)}.txt" + save_positions_to_file(positions, output_file) + else: + print("No positions extracted.") if __name__ == "__main__": @@ -188,7 +196,7 @@ def main(): import chess.pgn from tqdm import tqdm except ImportError: - print("Please install required dependencies:") + print("Please install dependencies:") print("pip install python-chess tqdm") exit(1) diff --git a/src/eval_constants.hpp b/src/eval_constants.hpp index dcc99089..f588247d 100644 --- a/src/eval_constants.hpp +++ b/src/eval_constants.hpp @@ -5,170 +5,172 @@ namespace Clockwork { // clang-format off -inline const PParam PAWN_MAT = S(287, 340); -inline const PParam KNIGHT_MAT = S(917, 685); -inline const PParam BISHOP_MAT = S(868, 596); -inline const PParam ROOK_MAT = S(1292, 1272); -inline const PParam QUEEN_MAT = S(2710, 1855); -inline const PParam TEMPO_VAL = S(59, 15); +inline const PParam PAWN_MAT = S(301, 307); +inline const PParam KNIGHT_MAT = S(916, 615); +inline const PParam BISHOP_MAT = S(835, 526); +inline const PParam ROOK_MAT = S(1359, 1179); +inline const PParam QUEEN_MAT = S(2765, 1705); +inline const PParam TEMPO_VAL = S(65, 13); -inline const PParam BISHOP_PAIR_VAL = S(80, 177); -inline const PParam ROOK_OPEN_VAL = S(101, -27); -inline const PParam ROOK_SEMIOPEN_VAL = S(36, 15); +inline const PParam BISHOP_PAIR_VAL = S(84, 168); +inline const PParam ROOK_OPEN_VAL = S(111, -29); +inline const PParam ROOK_SEMIOPEN_VAL = S(40, 17); -inline const PParam DOUBLED_PAWN_VAL = S(-36, -78); +inline const PParam DOUBLED_PAWN_VAL = S(-40, -75); -inline const PParam POTENTIAL_CHECKER_VAL = S(-66, -8); -inline const PParam OUTPOST_KNIGHT_VAL = S(53, 32); -inline const PParam OUTPOST_BISHOP_VAL = S(62, 22); +inline const PParam POTENTIAL_CHECKER_VAL = S(-67, -10); +inline const PParam OUTPOST_KNIGHT_VAL = S(53, 36); +inline const PParam OUTPOST_BISHOP_VAL = S(61, 28); -inline const PParam PAWN_PUSH_THREAT_KNIGHT = S(48, 7); -inline const PParam PAWN_PUSH_THREAT_BISHOP = S(55, -22); -inline const PParam PAWN_PUSH_THREAT_ROOK = S(34, 31); -inline const PParam PAWN_PUSH_THREAT_QUEEN = S(60, -50); +inline const PParam PAWN_PUSH_THREAT_KNIGHT = S(52, 3); +inline const PParam PAWN_PUSH_THREAT_BISHOP = S(59, -26); +inline const PParam PAWN_PUSH_THREAT_ROOK = S(34, 34); +inline const PParam PAWN_PUSH_THREAT_QUEEN = S(69, -58); inline const std::array PAWN_PHALANX = { - S(21, 18), S(62, 30), S(76, 70), S(186, 145), S(538, 261), S(940, 1138) + S(26, 15), S(62, 32), S(77, 69), S(188, 146), S(497, 246), S(721, 709) }; inline const std::array DEFENDED_PAWN = { - S(66, 40), S(60, 31), S(67, 58), S(150, 121), S(689, -86) + S(73, 38), S(64, 33), S(75, 66), S(178, 115), S(632, -29) }; inline const std::array PASSED_PAWN = { - S(-77, -92), S(-77, -69), S(-51, 5), S(25, 75), S(89, 196), S(291, 304) + S(-84, -98), S(-82, -75), S(-57, 5), S(24, 73), S(102, 197), S(316, 281) }; inline const std::array DEFENDED_PASSED_PUSH = { - S(50, -44), S(33, -3), S(20, 28), S(27, 71), S(106, 142), S(152, 290) + S(44, -36), S(33, -1), S(27, 24), S(20, 77), S(112, 126), S(217, 224) }; inline const std::array BLOCKED_PASSED_PAWN = { - S(13, -43), S(-1, 6), S(-3, -22), S(4, -44), S(-2, -89), S(-212, -130) + S(5, -33), S(-1, 5), S(-6, -20), S(-7, -41), S(-22, -87), S(-265, -111) }; inline const std::array FRIENDLY_KING_PASSED_PAWN_DISTANCE = { - S(0, 0), S(-3, 114), S(-14, 85), S(-7, 32), S(2, 4), S(11, 9), S(38, 8), S(17, -5) + S(0, 0), S(1, 113), S(-9, 88), S(-3, 33), S(-1, 8), S(11, 9), S(47, 7), S(21, -9) }; inline const std::array ENEMY_KING_PASSED_PAWN_DISTANCE = { - S(0, 0), S(-273, -17), S(-34, 19), S(-14, 45), S(29, 55), S(33, 82), S(50, 84), S(3, 102) + S(0, 0), S(-280, -12), S(-32, 25), S(-10, 48), S(32, 62), S(38, 84), S(51, 90), S(17, 92) }; inline const std::array KNIGHT_MOBILITY = { - S(-233, -223), S(-125, -63), S(-67, -10), S(-24, 24), S(23, 38), S(50, 76), S(88, 71), S(122, 73), - S(168, 16) + S(-234, -218), S(-125, -59), S(-69, -1), S(-29, 31), S(18, 40), S(47, 75), S(88, 65), S(126, 65), + S(175, 4) }; inline const std::array BISHOP_MOBILITY = { - S(-245, -299), S(-172, -116), S(-101, -51), S(-68, 0), S(-39, 33), S(-23, 55), S(-6, 71), S(12, 77), - S(30, 83), S(42, 80), S(66, 69), S(131, 22), S(158, 1), S(219, -32) + S(-251, -289), S(-178, -101), S(-107, -43), S(-71, 0), S(-44, 36), S(-21, 55), S(-7, 69), S(14, 73), + S(33, 81), S(52, 72), S(76, 59), S(137, 16), S(147, 12), S(216, -36) }; inline const std::array ROOK_MOBILITY = { - S(-300, -251), S(-148, -97), S(-96, -33), S(-63, -24), S(-38, 0), S(-24, 22), S(-6, 34), S(12, 40), - S(29, 53), S(46, 62), S(63, 65), S(74, 68), S(94, 71), S(106, 57), S(251, -70) + S(-279, -358), S(-165, -63), S(-108, -18), S(-70, -13), S(-43, 8), S(-28, 34), S(-10, 48), S(8, 51), + S(25, 61), S(41, 69), S(61, 69), S(76, 72), S(91, 74), S(119, 53), S(277, -87) }; inline const std::array QUEEN_MOBILITY = { - S(-790, -500), S(-307, -661), S(-217, -505), S(-165, -301), S(-158, -95), S(-122, 11), S(-119, 121), S(-95, 133), - S(-91, 189), S(-79, 213), S(-70, 238), S(-65, 253), S(-47, 244), S(-37, 255), S(-31, 249), S(-19, 244), - S(-12, 234), S(-13, 240), S(12, 196), S(34, 157), S(51, 137), S(96, 69), S(108, 61), S(272, -118), - S(308, -162), S(544, -315), S(373, -232), S(625, -368) + S(-833, -482), S(-295, -631), S(-193, -534), S(-140, -337), S(-131, -131), S(-100, 7), S(-95, 108), S(-69, 129), + S(-61, 178), S(-50, 200), S(-39, 229), S(-30, 232), S(-9, 219), S(-7, 238), S(-1, 236), S(11, 226), + S(26, 206), S(23, 214), S(39, 181), S(68, 143), S(80, 123), S(134, 49), S(144, 42), S(258, -89), + S(356, -174), S(439, -263), S(125, -86), S(343, -233) }; inline const std::array KING_MOBILITY = { - S(561, -76), S(168, -162), S(37, -56), S(0, -6), S(-50, 11), S(-112, 36), S(-159, 77), S(-207, 99), - S(-242, 77) + S(516, -126), S(185, -152), S(51, -46), S(6, 2), S(-45, 15), S(-106, 40), S(-155, 78), S(-206, 100), + S(-246, 85) }; inline const std::array KNIGHT_KING_RING = { - S(206, 236), S(309, 188), S(383, 128) + S(269, 228), S(357, 199), S(424, 158) }; inline const std::array BISHOP_KING_RING = { - S(378, 386), S(229, 246), S(144, 77) + S(478, 377), S(295, 239), S(164, 74) }; inline const std::array ROOK_KING_RING = { - S(414, 435), S(515, 416), S(539, 428), S(620, 462), S(695, 433) + S(425, 395), S(536, 393), S(560, 409), S(645, 461), S(767, 434) }; inline const std::array QUEEN_KING_RING = { - S(1034, 998), S(716, 810), S(422, 616), S(204, 352), S(104, 38), S(26, -329) + S(1124, 911), S(777, 741), S(457, 560), S(201, 319), S(81, 12), S(-51, -326) }; inline const std::array PT_INNER_RING_ATTACKS = { - S(-110, 63), S(9, -18), S(-183, -134), S(34, 31), S(-279, -205) + S(-113, 60), S(-1, 0), S(-224, -129), S(42, 43), S(-307, -189) }; inline const std::array PT_OUTER_RING_ATTACKS = { - S(-27, 23), S(-21, 18), S(-25, 13), S(-15, 9), S(-22, -8) + S(-26, 23), S(-23, 21), S(-23, 15), S(-15, 9), S(-21, -11) }; -inline const PParam PAWN_THREAT_KNIGHT = S(234, 62); -inline const PParam PAWN_THREAT_BISHOP = S(204, 104); -inline const PParam PAWN_THREAT_ROOK = S(199, 56); -inline const PParam PAWN_THREAT_QUEEN = S(175, -53); +inline const PParam PAWN_THREAT_KNIGHT = S(247, 58); +inline const PParam PAWN_THREAT_BISHOP = S(232, 122); +inline const PParam PAWN_THREAT_ROOK = S(209, 84); +inline const PParam PAWN_THREAT_QUEEN = S(189, -39); -inline const PParam KNIGHT_THREAT_BISHOP = S(106, 73); -inline const PParam KNIGHT_THREAT_ROOK = S(246, 4); -inline const PParam KNIGHT_THREAT_QUEEN = S(161, -70); +inline const PParam KNIGHT_THREAT_BISHOP = S(121, 69); +inline const PParam KNIGHT_THREAT_ROOK = S(257, 13); +inline const PParam KNIGHT_THREAT_QUEEN = S(166, -56); -inline const PParam BISHOP_THREAT_KNIGHT = S(112, 34); -inline const PParam BISHOP_THREAT_ROOK = S(245, 56); -inline const PParam BISHOP_THREAT_QUEEN = S(192, 54); +inline const PParam BISHOP_THREAT_KNIGHT = S(121, 38); +inline const PParam BISHOP_THREAT_ROOK = S(248, 66); +inline const PParam BISHOP_THREAT_QUEEN = S(199, 51); inline const std::array BISHOP_PAWNS = { - S(0, -12), S(-3, -6), S(-3, -16), S(-8, -27), S(-14, -32), S(-19, -40), S(-20, -47), S(-26, -42), - S(-34, -48) + S(1, -17), S(-4, -5), S(-4, -16), S(-9, -23), S(-16, -29), S(-22, -34), S(-25, -43), S(-29, -48), + S(-39, -52) }; inline const std::array PAWN_PSQT = { - S(133, 152), S(99, 213), S(182, 175), S(246, 56), S(195, 49), S(186, 110), S(74, 134), S(153, 99), - S(21, 52), S(113, 90), S(92, 36), S(98, -23), S(61, -44), S(11, 1), S(-24, 43), S(-60, 40), - S(-23, 1), S(-15, 13), S(-2, -27), S(-7, -44), S(-18, -51), S(-54, -49), S(-93, -3), S(-111, 13), - S(-29, -59), S(-17, -24), S(-19, -57), S(-35, -55), S(-57, -65), S(-75, -58), S(-125, -9), S(-143, -22), - S(-33, -87), S(26, -78), S(-25, -31), S(-49, -33), S(-69, -43), S(-105, -47), S(-124, -33), S(-146, -40), - S(-34, -76), S(90, -66), S(42, -27), S(-6, -14), S(-38, -30), S(-68, -40), S(-93, -15), S(-127, -28) + S(185, 136), S(126, 199), S(191, 162), S(166, 103), S(218, 36), S(177, 93), S(93, 128), S(168, 98), + S(36, 29), S(109, 75), S(87, 19), S(52, -11), S(45, -44), S(12, -8), S(-10, 24), S(-55, 34), + S(-16, -8), S(-11, 10), S(4, -33), S(-9, -42), S(-23, -49), S(-56, -39), S(-91, -4), S(-104, 9), + S(-31, -56), S(-15, -26), S(-12, -49), S(-32, -48), S(-59, -54), S(-76, -47), S(-123, -9), S(-143, -18), + S(-38, -81), S(25, -74), S(-29, -29), S(-42, -31), S(-75, -33), S(-107, -36), S(-127, -24), S(-145, -36), + S(-37, -72), S(88, -65), S(37, -26), S(-5, -10), S(-41, -20), S(-70, -28), S(-91, -12), S(-128, -18) }; inline const std::array KNIGHT_PSQT = { - S(-396, -157), S(-342, 56), S(-455, 232), S(-124, 68), S(-256, 96), S(-323, 92), S(-565, 82), S(-535, -22), - S(5, -4), S(72, 9), S(165, -51), S(113, 11), S(120, 15), S(51, -3), S(10, 4), S(-11, -42), - S(58, -28), S(84, 28), S(167, 25), S(121, 46), S(118, 37), S(46, 42), S(42, 9), S(-38, 8), - S(107, 9), S(92, 37), S(109, 52), S(91, 77), S(104, 61), S(71, 54), S(58, 3), S(38, 5), - S(100, -8), S(124, -4), S(109, 30), S(90, 39), S(78, 50), S(83, 39), S(60, 6), S(49, -56), - S(28, -31), S(56, -40), S(49, -17), S(62, 26), S(70, 21), S(15, -3), S(21, -44), S(-16, -52), - S(37, -20), S(59, -47), S(40, -39), S(39, -19), S(26, -26), S(-2, -46), S(11, -61), S(-47, -130), - S(-15, -69), S(21, -24), S(44, -50), S(52, -43), S(43, -35), S(-5, -66), S(-18, -41), S(-65, -95) + S(-433, -145), S(-401, 80), S(-499, 264), S(-89, 43), S(-159, 39), S(-265, 7), S(-594, 122), S(-501, -80), + S(-26, 8), S(37, -5), S(126, -42), S(126, -10), S(117, 0), S(52, -11), S(23, 1), S(-14, -38), + S(59, -44), S(101, 6), S(155, 20), S(133, 37), S(99, 39), S(57, 43), S(49, -5), S(-31, 5), + S(96, 8), S(99, 18), S(113, 45), S(89, 77), S(108, 69), S(79, 53), S(70, 11), S(43, -11), + S(90, -5), S(133, -18), S(110, 29), S(104, 44), S(83, 54), S(83, 38), S(72, 0), S(42, -14), + S(26, -28), S(60, -38), S(44, 3), S(66, 28), S(70, 24), S(18, 17), S(21, -33), S(-21, -40), + S(30, -35), S(55, -53), S(40, -38), S(42, -13), S(24, -8), S(0, -46), S(11, -46), S(-35, -109), + S(-24, -64), S(11, -1), S(35, -40), S(51, -43), S(40, -33), S(-9, -51), S(-19, -24), S(-76, -85) }; inline const std::array BISHOP_PSQT = { - S(-179, 73), S(-197, 51), S(-441, 78), S(-318, 90), S(-275, 94), S(-436, 117), S(-175, 94), S(-132, 72), - S(-4, -41), S(-17, 31), S(-2, 15), S(-35, 24), S(-55, 37), S(-4, 23), S(-24, 12), S(-61, 16), - S(32, 13), S(71, 2), S(139, 15), S(69, 13), S(42, 18), S(24, 26), S(90, -6), S(-1, 10), - S(37, -28), S(36, 11), S(73, 11), S(75, 35), S(86, 32), S(24, 30), S(21, 3), S(-18, 6), - S(31, -51), S(51, -16), S(59, 1), S(61, 22), S(55, 37), S(22, 25), S(4, -12), S(-9, -53), - S(54, -48), S(101, -38), S(111, -26), S(46, 16), S(32, 20), S(37, 13), S(62, -29), S(26, -48), - S(45, -83), S(110, -69), S(80, -56), S(51, -25), S(38, -38), S(38, -48), S(21, -35), S(36, -96), - S(45, -67), S(31, -18), S(45, -18), S(56, -52), S(64, -62), S(53, -17), S(48, -47), S(27, -48) + S(-164, 57), S(-178, 51), S(-407, 90), S(-319, 79), S(-264, 90), S(-428, 125), S(-211, 109), S(-141, 71), + S(4, -34), S(-30, 31), S(-27, 10), S(-70, 42), S(-88, 47), S(-34, 34), S(-20, 15), S(-41, 6), + S(42, -5), S(79, -5), S(100, 17), S(52, 29), S(31, 19), S(36, 19), S(61, 3), S(10, -18), + S(38, -30), S(45, 2), S(77, 6), S(62, 46), S(88, 39), S(29, 21), S(26, -9), S(-6, -8), + S(41, -48), S(60, -27), S(63, 2), S(72, 18), S(60, 28), S(21, 26), S(-6, -4), S(-5, -44), + S(50, -41), S(97, -40), S(111, -23), S(60, 13), S(51, 3), S(51, -2), S(61, -22), S(12, -24), + S(42, -91), S(117, -77), S(82, -58), S(57, -35), S(38, -26), S(40, -47), S(35, -42), S(34, -87), + S(44, -71), S(34, -25), S(44, -18), S(55, -49), S(66, -57), S(61, -12), S(50, -43), S(28, -49) }; inline const std::array ROOK_PSQT = { - S(105, 5), S(173, 3), S(100, 33), S(97, 30), S(111, 17), S(62, 29), S(73, 30), S(80, 36), - S(18, 59), S(101, 37), S(171, 18), S(94, 63), S(117, 47), S(68, 55), S(11, 70), S(5, 75), - S(-4, 42), S(141, 3), S(165, 2), S(162, 0), S(123, 7), S(56, 43), S(78, 28), S(-39, 76), - S(-28, 33), S(47, 28), S(74, 22), S(92, -12), S(70, 7), S(14, 51), S(-1, 49), S(-73, 57), - S(-89, -16), S(-8, -10), S(-25, 8), S(-43, 10), S(-41, 2), S(-57, 40), S(-83, 35), S(-105, 26), - S(-110, -38), S(-36, -62), S(-41, -34), S(-62, -32), S(-41, -53), S(-90, 0), S(-91, -19), S(-114, -21), - S(-169, -29), S(-67, -90), S(-46, -73), S(-44, -69), S(-48, -66), S(-68, -49), S(-86, -73), S(-118, -56), - S(-133, -31), S(-104, -23), S(-50, -58), S(-28, -73), S(-40, -59), S(-52, -48), S(-67, -57), S(-85, -40) + S(106, 8), S(163, 7), S(86, 43), S(143, 5), S(123, 6), S(66, 29), S(82, 24), S(100, 24), + S(0, 64), S(78, 42), S(144, 27), S(93, 49), S(115, 36), S(78, 36), S(-3, 70), S(-13, 76), + S(-8, 45), S(114, 8), S(155, -2), S(144, 0), S(120, -3), S(63, 30), S(71, 27), S(-25, 76), + S(-38, 35), S(37, 33), S(82, 19), S(89, -3), S(68, 11), S(21, 37), S(-7, 47), S(-63, 54), + S(-84, -12), S(-8, -6), S(-29, 12), S(-30, 4), S(-32, 2), S(-43, 30), S(-77, 29), S(-94, 22), + S(-109, -40), S(-47, -44), S(-43, -22), S(-59, -29), S(-45, -44), S(-77, 2), S(-82, -20), S(-123, -9), + S(-181, -26), S(-75, -87), S(-50, -66), S(-45, -64), S(-47, -66), S(-63, -49), S(-78, -70), S(-116, -49), + S(-139, -31), S(-113, -20), S(-49, -59), S(-23, -72), S(-39, -58), S(-52, -42), S(-70, -45), S(-88, -26) }; inline const std::array QUEEN_PSQT = { - S(32, 30), S(72, -8), S(69, 0), S(-55, 132), S(17, 48), S(-28, 71), S(47, -15), S(-18, 15), - S(20, 75), S(-52, 172), S(-63, 219), S(-165, 251), S(-136, 200), S(-119, 194), S(-70, 113), S(-33, 50), - S(-23, 102), S(54, 99), S(-16, 178), S(-41, 187), S(-77, 169), S(-102, 172), S(-13, 63), S(-40, 36), - S(42, 9), S(31, 78), S(-5, 116), S(-17, 176), S(-23, 156), S(-29, 98), S(13, 3), S(6, -24), - S(9, 38), S(41, 3), S(15, 71), S(-20, 122), S(-24, 116), S(-17, 79), S(-1, -4), S(-1, -53), - S(24, -114), S(50, -57), S(54, 4), S(2, 18), S(19, -17), S(26, -10), S(37, -80), S(8, -75), - S(15, -212), S(53, -304), S(47, -181), S(61, -116), S(34, -92), S(51, -158), S(26, -95), S(7, -95), - S(-40, -123), S(24, -374), S(29, -375), S(55, -291), S(56, -197), S(57, -245), S(44, -202), S(-11, -125) + S(19, 19), S(70, -20), S(77, -14), S(-29, 118), S(18, 41), S(-14, 64), S(20, -19), S(-56, 50), + S(-1, 88), S(-37, 165), S(-57, 192), S(-176, 235), S(-152, 206), S(-102, 174), S(-64, 80), S(-42, 73), + S(-36, 101), S(41, 105), S(-39, 178), S(-48, 186), S(-96, 177), S(-90, 161), S(-13, 42), S(-49, 40), + S(34, 15), S(32, 73), S(-13, 126), S(-11, 159), S(-34, 165), S(-46, 117), S(14, 17), S(17, -40), + S(11, 15), S(36, 20), S(11, 64), S(-11, 107), S(-10, 98), S(-11, 56), S(8, -8), S(15, -51), + S(18, -97), S(47, -55), S(51, 16), S(20, 2), S(21, -17), S(30, 0), S(34, -59), S(10, -71), + S(8, -201), S(52, -270), S(55, -183), S(69, -125), S(51, -113), S(55, -146), S(20, -79), S(19, -107), + S(-37, -133), S(24, -360), S(31, -372), S(52, -255), S(62, -191), S(57, -239), S(47, -209), S(-6, -137) }; inline const std::array KING_PSQT = { - S(-281, -298), S(18, 20), S(-127, 72), S(-233, 98), S(-19, -5), S(-19, -5), S(-19, -5), S(-19, -5), - S(110, -70), S(81, 88), S(126, 66), S(233, -3), S(-19, -5), S(-19, -5), S(-19, -5), S(-19, -5), - S(-76, 87), S(175, 70), S(243, 28), S(219, -13), S(-19, -5), S(-19, -5), S(-19, -5), S(-19, -5), - S(-284, 108), S(163, 23), S(173, 11), S(114, -6), S(-19, -5), S(-19, -5), S(-19, -5), S(-19, -5), - S(-256, 61), S(94, -10), S(149, -26), S(61, 7), S(-19, -5), S(-19, -5), S(-19, -5), S(-19, -5), - S(-190, 44), S(155, -48), S(113, -23), S(85, -7), S(-19, -5), S(-19, -5), S(-19, -5), S(-19, -5), - S(-30, -18), S(157, -63), S(110, -39), S(45, -6), S(-19, -5), S(-19, -5), S(-19, -5), S(-19, -5), - S(-230, 2), S(-49, -17), S(-127, 5), S(-141, 4), S(-19, -5), S(-19, -5), S(-19, -5), S(-19, -5) + S(-651, -130), S(-65, 85), S(-51, 70), S(-173, 33), S(-2, -9), S(-2, -9), S(-2, -9), S(-2, -9), + S(62, -68), S(62, 119), S(148, 75), S(162, -2), S(-2, -9), S(-2, -9), S(-2, -9), S(-2, -9), + S(-35, 72), S(178, 66), S(222, 44), S(172, 5), S(-2, -9), S(-2, -9), S(-2, -9), S(-2, -9), + S(-287, 84), S(158, 18), S(100, 19), S(66, 11), S(-2, -9), S(-2, -9), S(-2, -9), S(-2, -9), + S(-278, 55), S(79, -9), S(103, -12), S(28, 17), S(-2, -9), S(-2, -9), S(-2, -9), S(-2, -9), + S(-168, 30), S(154, -46), S(111, -21), S(84, -2), S(-2, -9), S(-2, -9), S(-2, -9), S(-2, -9), + S(-10, -35), S(179, -69), S(128, -41), S(61, -7), S(-2, -9), S(-2, -9), S(-2, -9), S(-2, -9), + S(-210, -22), S(-25, -30), S(-108, -1), S(-127, -2), S(-2, -9), S(-2, -9), S(-2, -9), S(-2, -9) }; + +// Epoch duration: 8.05849s // Epoch duration: 6.87929s // clang-format on } // namespace Clockwork diff --git a/src/evaltune_main.cpp b/src/evaltune_main.cpp index f567af24..24ad9696 100644 --- a/src/evaltune_main.cpp +++ b/src/evaltune_main.cpp @@ -32,9 +32,9 @@ int main() { std::vector positions; std::vector results; - const std::vector fenFiles = { - "data/dfrcv1.txt", "data/dfrcv0.txt", "data/v2.2.txt", "data/v2.1.txt", "data/v3.txt", - }; + const std::vector fenFiles = {"data/dfrcv1.txt", "data/dfrcv0.txt", + "data/v3.txt", "v4_5knpm.txt", + "v4_8knpm.txt", "v4_16knpm.txt"}; const u32 thread_count = std::max(1, std::thread::hardware_concurrency() / 2); diff --git a/src/movepick.cpp b/src/movepick.cpp index a0a8f36e..276240b5 100644 --- a/src/movepick.cpp +++ b/src/movepick.cpp @@ -2,6 +2,7 @@ #include "immintrin.h" #include "see.hpp" #include "tuned.hpp" +#include "util/random.hpp" #include #include @@ -172,6 +173,19 @@ std::pair MovePicker::pick_next(MoveList& moves) { return {moves[m_current_index], m_scores[m_current_index++]}; } +Move RandomMovePicker::next() { + if (m_noisy.empty() && m_quiet.empty()) { + return Move::none(); // No moves available + } + // Get a random index in the range [0, quiets.size() + noisies.size()) - 1] + usize idx = Clockwork::Random::rand_64() % (m_noisy.size() + m_quiet.size()); + if (idx < m_noisy.size()) { + return m_noisy[idx]; + } else { + return m_quiet[idx - m_noisy.size()]; + } +} + i32 MovePicker::score_move(Move move) const { if (quiet_move(move)) { return m_history.get_quiet_stats(m_pos, move, m_ply, m_stack); diff --git a/src/movepick.hpp b/src/movepick.hpp index 610d538f..cb0c8aea 100644 --- a/src/movepick.hpp +++ b/src/movepick.hpp @@ -89,4 +89,20 @@ class MovePicker { std::optional m_threshold; }; +// Random movepicker class for data generation. It won't score moves, just pick them randomly. +class RandomMovePicker { +public: + explicit RandomMovePicker(const Position& pos) : + m_movegen(pos) { + m_movegen.generate_moves(m_noisy, m_quiet); + }; + + Move next(); + +private: + MoveGen m_movegen; + MoveList m_noisy; + MoveList m_quiet; +}; + } // namespace Clockwork diff --git a/src/position.cpp b/src/position.cpp index 948c7abf..b80d8731 100644 --- a/src/position.cpp +++ b/src/position.cpp @@ -1083,8 +1083,4 @@ u16 Position::get_50mr_counter() const { return m_50mr; } -u16 Position::get_ply() const { - return m_ply; -} - } // namespace Clockwork diff --git a/src/position.hpp b/src/position.hpp index 5118a786..92fa28e7 100644 --- a/src/position.hpp +++ b/src/position.hpp @@ -231,6 +231,10 @@ struct Position { return true; } + [[nodiscard]] u16 get_ply() const { + return m_ply; + } + [[nodiscard]] bool is_insufficient_material() const { auto wpcnt = piece_count(Color::White); auto bpcnt = piece_count(Color::Black); @@ -269,8 +273,6 @@ struct Position { [[nodiscard]] u16 get_50mr_counter() const; - [[nodiscard]] u16 get_ply() const; - [[nodiscard]] bool is_reversible(Move move); const std::array calc_attacks_slow(); diff --git a/src/search.cpp b/src/search.cpp index 6f122f03..07f9fc98 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -90,6 +90,17 @@ void Searcher::wait() { std::unique_lock lock_guard{mutex}; } +Value Searcher::wait_for_score() { + // Make sure this is being used on only the main thread. + if (m_workers.empty() || m_workers[0]->thread_type() != ThreadType::MAIN) { + throw std::logic_error("wait_for_score can only be called from the main thread"); + } + // Protect the read of root_score with a unique_lock. + std::unique_lock lock_guard{mutex}; + // Return the final score from the main thread's search. + return m_workers[0]->get_thread_data().root_score; +} + void Searcher::initialize(size_t thread_count) { if (m_workers.size() == thread_count) { return; @@ -312,6 +323,8 @@ Move Worker::iterative_deepening(const Position& root_position) { last_best_move = last_pv.first_move(); base_search_score = search_depth == 1 ? score : base_search_score; + m_td.root_score = last_search_score; + // Check depth limit if (IS_MAIN && search_depth >= m_search_limits.depth_limit) { break; diff --git a/src/search.hpp b/src/search.hpp index 53a8f07c..427bbb84 100644 --- a/src/search.hpp +++ b/src/search.hpp @@ -28,6 +28,7 @@ struct SearchSettings { u64 hard_nodes = 0; u64 soft_nodes = 0; bool silent = false; + bool datagen = false; }; // Forward declare for Searcher @@ -80,6 +81,7 @@ struct SearchLimits { struct ThreadData { History history; std::vector psqt_states; + Value root_score; PsqtState& push_psqt_state() { psqt_states.push_back(psqt_states.back()); @@ -109,12 +111,13 @@ class Searcher { Searcher(); ~Searcher(); - void set_position(const Position& root_position, const RepetitionInfo& repetition_info); - void launch_search(SearchSettings settings); - void stop_searching(); - void wait(); - void initialize(size_t thread_count); - void exit(); + void set_position(const Position& root_position, const RepetitionInfo& repetition_info); + void launch_search(SearchSettings settings); + void stop_searching(); + void wait(); + Value wait_for_score(); + void initialize(size_t thread_count); + void exit(); u64 node_count(); void reset(); @@ -153,6 +156,10 @@ class alignas(128) Worker { return m_search_nodes.load(std::memory_order_relaxed); } + [[nodiscard]] const ThreadData& get_thread_data() const { + return m_td; + } + [[nodiscard]] Value get_draw_score() const { return (search_nodes() & 3) - 2; // Randomize between -2 and +2 } diff --git a/src/uci.cpp b/src/uci.cpp index 175b52eb..15393f1b 100644 --- a/src/uci.cpp +++ b/src/uci.cpp @@ -2,6 +2,7 @@ #include "bench.hpp" #include "evaluation.hpp" #include "move.hpp" +#include "movepick.hpp" #include "perft.hpp" #include "position.hpp" #include "search.hpp" @@ -304,9 +305,10 @@ void UCIHandler::handle_perft(std::istringstream& is) { } void UCIHandler::handle_genfens(std::istringstream& is) { - int N = 0; + i32 N = 0; uint64_t seed = 0; bool seed_provided = false; + i32 rand_count = 4; std::string book = "None"; std::string token; @@ -328,67 +330,104 @@ void UCIHandler::handle_genfens(std::istringstream& is) { std::cout << "Missing book filename after 'book'." << std::endl; return; } + } else if (token == "randmoves") { + if (!(is >> rand_count) || rand_count < 0) { + std::cout << "Invalid randmoves value." << std::endl; + return; + } } else { std::cout << "Invalid genfens argument: " << token << std::endl; return; } } - // Require book file - if (book == "None") { - std::cout << "Please specify a book file using 'book '." << std::endl; - return; - } - // Set RNG state if (!seed_provided) { std::cout << "Seed not provided. Defaulting to 0." << std::endl; } Clockwork::Random::seed({seed, seed, seed | 1, seed ^ 0xDEADBEEFDEADBEEFULL}); - // Open the book file - std::ifstream file(book); - if (!file) { - std::cout << "Could not open file: " << book << std::endl; - return; - } - - // Load all lines std::vector lines; std::string line; - while (std::getline(file, line)) { - if (!line.empty()) { - lines.push_back(line); + if (book != "None") { + std::cout << "Using book file: " << book << std::endl; + + // Open the book file + std::ifstream file(book); + if (!file) { + std::cout << "Could not open file: " << book << std::endl; + return; } - } - // Safety checks - if (lines.empty()) { - std::cout << "Book file is empty." << std::endl; - return; - } + // Load all lines + while (std::getline(file, line)) { + if (!line.empty()) { + lines.push_back(line); + } + } - if (N > static_cast(lines.size())) { - std::cout << "Requested " << N << " positions, but only " << lines.size() << " available." - << std::endl; - return; + if (lines.empty()) { + std::cout << "Book file is empty." << std::endl; + return; + } + file.close(); + } else { + // Add startpositions to lines + lines.push_back("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"); } - // Pick N unique random indices - std::unordered_set selected_indices; - while (selected_indices.size() < static_cast(N)) { - uint64_t rand_val = Clockwork::Random::rand_64(); - selected_indices.insert(rand_val % lines.size()); - } + // Line generation is as follows: + // 1) Pick a random line from the book (or startposition if no book) + // 2) Play random legal moves (see passing noisy or quiet moves) + // Launch a 16k softnodes verification search to make sure the position doesn't lose immediately + // 3) If the position is legal, print it + i32 generated = 0; + + while (generated < N) { +reset: + // Pick a random line from the book + const std::string& selected_line = lines[Clockwork::Random::rand_64() % lines.size()]; + + // Set up position + Position pos = *Position::parse(selected_line); + + i32 moves = 0; + // Make 4 random moves out of book + while (moves < rand_count) { + RandomMovePicker picker(pos); + Move m = picker.next(); + if (m == Move::none()) { + // No moves available, skip + goto reset; + } + pos = pos.move(m); + moves++; + } + + // Mock search limits for datagen verification + Search::SearchSettings settings = { + .stm = pos.active_color(), .hard_nodes = 1048576, .soft_nodes = 16384, .silent = true}; + + searcher.initialize(1); // Initialize with 1 thread always for datagen + + RepetitionInfo rep_info; + rep_info.reset(); + rep_info.push(pos.get_hash_key(), false); - // Output the selected FENs - for (size_t idx : selected_indices) { - auto pos = Position::parse(lines[idx]); - if (!pos) { - std::cout << "Invalid FEN in book: " << lines[idx] << std::endl; - exit(-1); + searcher.set_position(pos, rep_info); + searcher.launch_search(settings); + + // Wait for the search to finish and get the score + Value score = searcher.wait_for_score(); + + if (std::abs(score) > 450) { + // Position is mate or losing, skip + goto reset; } - std::cout << "info string genfens " << *pos << std::endl; + + // If we reach here, the position is legal + std::cout << "info string genfens " << pos << std::endl; + generated++; } }