diff --git a/configs/mixed_landmark_0814_no_extend_qsa.json b/configs/mixed_landmark_0814_no_extend_qsa.json
index 1fc11ac2..c69ea1e5 100644
--- a/configs/mixed_landmark_0814_no_extend_qsa.json
+++ b/configs/mixed_landmark_0814_no_extend_qsa.json
@@ -7,24 +7,74 @@
     "__delta_attention_args": "window_0-diff_1-w_16-dense_decode-smooth",
     "using_extend": false,
     "dense_layers": [0, 1, 2, 47, 46, 45],
-    "mask_refresh_interval": [96],
+    "mask_refresh_interval": [96, 32, 16],
     "layers": [
         {
             "sliding_window_size": 1024,
             "sliding_window_size_for_masking_step": [1024, 1024, 1024],
-            "second_stage_k": 1024,
+            "second_stage_k": 2048,
             "sink_token_size": 1024,
             "sa_extend_backend": "self_extend",
-            "stages": [ { } ]
+            "stages": [
+                {
+                    "stage_block_size_q":128,
+                    "stage_block_stride_q":4,
+                    "stage_chunk_size":256,
+                    "stage_k":null,
+                    "stage_stride":1,
+                    "using_landmark":false
+                },
+                {
+                    "stage_block_size_q":64,
+                    "stage_block_stride_q":1,
+                    "stage_chunk_size":32,
+                    "stage_k":65536,
+                    "stage_stride":1,
+                    "using_landmark":false
+                },
+                {
+                    "stage_block_size_q":64,
+                    "stage_block_stride_q":1,
+                    "stage_chunk_size":8,
+                    "stage_k":8192,
+                    "stage_stride":1,
+                    "using_landmark":false
+                }
+            ]
         },
         {
             "sliding_window_size": 1024,
             "sliding_window_size_for_masking_step": [1024, 1024, 1024],
-            "second_stage_k": 1024,
+            "second_stage_k": 2048,
             "sink_token_size": 1024,
             "sa_extend_backend": "self_extend",
             "scan_extend_backend": "none",
-            "stages": [ { } ]
+            "stages": [
+                {
+                    "stage_block_size_q":128,
+                    "stage_block_stride_q":4,
+                    "stage_chunk_size":256,
+                    "stage_k":null,
+                    "stage_stride":1,
+                    "using_landmark":false
+                },
+                {
+                    "stage_block_size_q":64,
+                    "stage_block_stride_q":1,
+                    "stage_chunk_size":32,
+                    "stage_k":65536,
+                    "stage_stride":1,
+                    "using_landmark":false
+                },
+                {
+                    "stage_block_size_q":64,
+                    "stage_block_stride_q":1,
+                    "stage_chunk_size":8,
+                    "stage_k":8192,
+                    "stage_stride":1,
+                    "using_landmark":false
+                }
+            ]
         }
     ],
     "prefill_layers": [
diff --git a/src/hip_attn/utils/sglang_watchdog.py b/src/hip_attn/utils/sglang_watchdog.py
new file mode 100644
index 00000000..0baca976
--- /dev/null
+++ b/src/hip_attn/utils/sglang_watchdog.py
@@ -0,0 +1,148 @@
+import argparse
+import datetime
+import sys
+import os
+import subprocess
+import threading
+import time
+import traceback
+import requests
+
+def log(*args):
+    comment = " ".join([str(a) for a in args])
+    timestamp = "{:%Y-%m-%d %H:%M:%S}".format(datetime.datetime.now())
+    print(f"\033[91m[{timestamp} sglang_watchdog] {comment}\033[0m", flush=True)
+
+class Watchdog:
+    def __init__(
+        self,
+    ):
+        self.timeout_bootup = 600
+        self.timeout_tick = 60
+        self.sleep_step = 1
+        self.proc: subprocess.Popen = None
+        self.argv: list[str] = None
+        self.running: bool = True
+
+    def start_subprocess(self):
+        args = [
+            "python",
+            "-m",
+            "sglang.launch_server",
+            *self.argv
+        ]
+        flatten_args = " ".join(args)
+        log(f"Start subprocess using the following command: {flatten_args}")
+        self.proc = subprocess.Popen(args)
+        log(f"Start subprocess communication.")
+        return_code = self.proc.wait()
+        log(f"Return code is {return_code}")
+
+    def kill_subprocess(self):
+        log(f"Start kill subprocess")
+        if self.proc is not None:
+            self.proc.kill()
+            self.proc = None
+        subprocess.call(["pkill", "sglang"])
+        log(f"Finish kill subprocess")
+
+    def wait_for_health(self, timeout: int):
+        response = requests.get(self.health_endpoint, timeout=timeout)
+        response.raise_for_status()
+
+    def main_watchdog(self):
+        while True:
+            try:
+                t_boot = time.time()
+                booted = False
+                while self.proc is None:
+                    log("Watchdog is waiting for the process to start...")
+                    time.sleep(self.sleep_step)
+                while (
+                    (time.time() - t_boot) < self.timeout_bootup
+                    and self.proc.returncode is None
+                    and not booted
+                ):
+                    try:
+                        self.wait_for_health(timeout=self.timeout_bootup)
+                        log("Server booted successfully.")
+                        booted = True
+                    except (TimeoutError, requests.HTTPError, requests.ConnectionError):
+                        # NOTE: the server process may not have started yet
+                        pass
+                    time.sleep(self.sleep_step)
+
+                if not booted: raise TimeoutError()
+
+                while True:
+                    log("Try watch dog.")
+                    self.wait_for_health(timeout=self.timeout_tick)
+                    log("Done watch dog successfully.")
+                    time.sleep(self.timeout_tick)
+
+            except (TimeoutError, requests.HTTPError):
+                self.kill_subprocess()
+            except Exception as ex:
+                trace = traceback.format_exc()
+                log(f"Traceback:\n{trace}")
+                log(f"Unexpected error on watchdog thread: {ex}")
+                self.kill_subprocess()
+
+            time.sleep(self.sleep_step)
+
+    def main_starter(self):
+        while True:
+            self.start_subprocess()
+            time.sleep(self.sleep_step)
+
+    def start(self):
+        try:
+            if "--" in sys.argv:
+                my_args = sys.argv[1:sys.argv.index("--")]
+                argv = sys.argv[sys.argv.index("--") + 1:]
+            else:
+                my_args = []
+                argv = sys.argv[1:]
+
+            parser = argparse.ArgumentParser()
+            parser.add_argument("--timeout-bootup", default=self.timeout_bootup, type=int)
+            parser.add_argument("--timeout", default=self.timeout_tick, type=int)
+            parser.add_argument("--sleep-step", default=self.sleep_step, type=int)
+
+            args = parser.parse_args(my_args)
+            self.timeout_bootup = args.timeout_bootup
+            self.timeout_tick = args.timeout
+            self.sleep_step = args.sleep_step
+
+            assert "--host" in argv
+            assert "--port" in argv
+            self.host = argv[argv.index("--host") + 1]
+            self.port = argv[argv.index("--port") + 1]
+            self.health_endpoint = f"http://{self.host}:{self.port}/health"
+            log(f"Watching: {self.health_endpoint}")
+
+            self.argv = argv
+
+            self.thread_watchdog = threading.Thread(
+                target=self.main_watchdog,
+                daemon=True
+            )
+            self.thread_starter = threading.Thread(
+                target=self.main_starter,
+                daemon=True
+            )
+
+            self.thread_starter.start()
+            time.sleep(self.sleep_step)
+            self.thread_watchdog.start()
+
+            self.thread_watchdog.join()
+            self.thread_starter.join()
+
+            self.running = False
+        except KeyboardInterrupt:
+            self.kill_subprocess()
+
+if __name__ == '__main__':
+    dog = Watchdog()
+    dog.start()
\ No newline at end of file
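Note on the new watchdog entry point (src/hip_attn/utils/sglang_watchdog.py): arguments before a literal "--" configure the watchdog itself, and everything after "--" is forwarded verbatim to `python -m sglang.launch_server`; `--host` and `--port` must appear in the forwarded part because the watchdog polls http://<host>:<port>/health. A minimal invocation sketch, assuming the package is importable as `hip_attn.utils.sglang_watchdog` and using a placeholder model path:

import subprocess

# Watchdog flags come first, then "--", then the untouched sglang.launch_server
# arguments. The model path below is a placeholder, not taken from this patch.
subprocess.run(
    [
        "python", "-m", "hip_attn.utils.sglang_watchdog",
        "--timeout-bootup", "600",  # seconds allowed for the server to become healthy
        "--timeout", "60",          # seconds between /health polls once booted
        "--",
        "--model-path", "<model>",
        "--host", "0.0.0.0",        # required: used to build the health endpoint
        "--port", "30000",          # required: used to build the health endpoint
    ]
)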
diff --git a/src/hip_attn/v1_2/attention_extend.py b/src/hip_attn/v1_2/attention_extend.py
index 36ddf535..449c2c02 100644
--- a/src/hip_attn/v1_2/attention_extend.py
+++ b/src/hip_attn/v1_2/attention_extend.py
@@ -16,7 +16,12 @@
 from hip_attn.utils.rope import adjust_rope
 from hip_attn.v1_2.attention_decode_bsa import decode_block_sparse_attention
 from hip_attn.v1_2.attention_extend_bsa import block_sparse_attention
-from hip_attn.v1_2.attention_extend_bsa_tilelang import block_sparse_attention_tilelang
+
+try:
+    from hip_attn.v1_2.attention_extend_bsa_tilelang import block_sparse_attention_tilelang
+except (ImportError, OSError):
+    block_sparse_attention_tilelang = None
+
 from hip_attn.v1_2.attention_metadata import (
     EnsembleScoreStage,
     EvalScoreStage,
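Note on the guarded TileLang import above: when the optional TileLang build cannot be imported (or fails to load its shared library, hence the OSError), `block_sparse_attention_tilelang` is left as None. A small sketch of the fallback selection this enables; the assumption that call sites then pick the Triton kernel is mine and is not shown in this hunk:

from hip_attn.v1_2.attention_extend_bsa import block_sparse_attention

try:
    from hip_attn.v1_2.attention_extend_bsa_tilelang import block_sparse_attention_tilelang
except (ImportError, OSError):
    block_sparse_attention_tilelang = None

# Prefer the optional TileLang kernel when it imported cleanly; otherwise use the
# always-available Triton block-sparse kernel.
bsa_kernel = block_sparse_attention_tilelang or block_sparse_attention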
diff --git a/src/hip_attn/v1_2/paged_hip.py b/src/hip_attn/v1_2/paged_hip.py
index 58fca4b9..6a5f4c1a 100644
--- a/src/hip_attn/v1_2/paged_hip.py
+++ b/src/hip_attn/v1_2/paged_hip.py
@@ -2,17 +2,85 @@
 import math
 import os
 import warnings
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Union
 
 import cv2
 import numba
 import numpy as np
 import torch
 import triton
-from flash_attn import flash_attn_func
 from matplotlib import pyplot as plt
-from sgl_kernel.flash_attn import flash_attn_varlen_func as __flash_attn_varlen_func
-from sgl_kernel.flash_attn import flash_attn_with_kvcache
+
+try:
+    from flash_attn import flash_attn_func
+except ImportError:
+    flash_attn_func = None
+
+try:
+    from sgl_kernel.flash_attn import flash_attn_varlen_func as __flash_attn_varlen_func
+    from sgl_kernel.flash_attn import flash_attn_with_kvcache
+    IS_AMD = False
+except ImportError:
+    # FIXME: better AMD detection algorithm
+    IS_AMD = True
+
+    from flash_attn import flash_attn_varlen_func as __flash_attn_varlen_func
+    from flash_attn import flash_attn_with_kvcache as __flash_attn_with_kvcache
+
+    def flash_attn_with_kvcache(
+        q,
+        k_cache,
+        v_cache,
+        k=None,
+        v=None,
+        qv=None,
+        rotary_cos=None,
+        rotary_sin=None,
+        cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
+        cache_batch_idx: Optional[torch.Tensor] = None,
+        cache_leftpad: Optional[torch.Tensor] = None,
+        page_table: Optional[torch.Tensor] = None,
+        cu_seqlens_q: Optional[torch.Tensor] = None,
+        cu_seqlens_k_new: Optional[torch.Tensor] = None,
+        max_seqlen_q: Optional[int] = None,
+        rotary_seqlens: Optional[torch.Tensor] = None,
+        q_descale: Optional[torch.Tensor] = None,
+        k_descale: Optional[torch.Tensor] = None,
+        v_descale: Optional[torch.Tensor] = None,
+        softmax_scale=None,
+        causal=False,
+        window_size=(-1, -1),  # -1 means infinite context window
+        softcap=0.0,  # 0.0 means deactivated
+        rotary_interleaved=True,
+        scheduler_metadata=None,
+        num_splits=0,  # Can be tuned for speed
+        pack_gqa=None,  # Can be tuned for speed
+        sm_margin=0,  # Can be tuned if some SMs are used for communication
+        return_softmax_lse=False,
+        sinks=None,
+        ver=3,
+    ):
+        return __flash_attn_with_kvcache(
+            q,
+            k_cache,
+            v_cache,
+            k=k,
+            v=v,
+            rotary_cos=rotary_cos,
+            rotary_sin=rotary_sin,
+            cache_seqlens=cache_seqlens,
+            cache_batch_idx=cache_batch_idx,
+            cache_leftpad=cache_leftpad,
+            block_table=page_table,
+            softmax_scale=softmax_scale,
+            causal=causal,
+            window_size=window_size,  # -1 means infinite context window
+            softcap=softcap,  # 0.0 means deactivated
+            rotary_interleaved=rotary_interleaved,
+            alibi_slopes=None,
+            num_splits=num_splits,
+            return_softmax_lse=return_softmax_lse,
+        )
 
 from hip_attn.v1_2.hip_config import HiPAttentionConfig
 from hip_attn.v1_2.utils import capture
diff --git a/src/hip_attn/v1_2/query_sparse_attention.py b/src/hip_attn/v1_2/query_sparse_attention.py
index d001e213..c862ba1f 100644
--- a/src/hip_attn/v1_2/query_sparse_attention.py
+++ b/src/hip_attn/v1_2/query_sparse_attention.py
@@ -1907,7 +1907,7 @@ def forward(
     assert rope_cos.ndim == 2
     assert extend_backend in ["self_extend", "nope"]
 
-    if rope_sin is not None:
+    if (rope_sin is not None) and (extend_backend in ["self_extend"]):
         HEAD_DIM_K_ROPE = rope_sin.shape[-1]
         HEAD_DIM_K_NOPE = HEAD_DIM_K - HEAD_DIM_K_ROPE
     else:
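Note on the FA2 fallback in paged_hip.py: when `sgl_kernel` is not importable (used here as a proxy for running on AMD, per the FIXME), the shim forwards only the arguments the wrapped flash_attn `flash_attn_with_kvcache` accepts, renames `page_table` to `block_table`, and silently drops the rest (`qv`, `cu_seqlens_q`, `cu_seqlens_k_new`, `max_seqlen_q`, `rotary_seqlens`, the `*_descale` tensors, `scheduler_metadata`, `pack_gqa`, `sm_margin`, `sinks`, `ver`). A hypothetical guard, not part of the patch, that a caller could use to fail fast instead of silently losing those inputs:

# Hypothetical helper (not in the patch): reject FA3-only kwargs on the FA2 fallback path.
_FA3_ONLY_ARGS = (
    "qv", "cu_seqlens_q", "cu_seqlens_k_new", "max_seqlen_q", "rotary_seqlens",
    "q_descale", "k_descale", "v_descale", "scheduler_metadata", "pack_gqa", "sinks",
)

def check_fa2_fallback_kwargs(is_amd: bool, **kwargs) -> None:
    # Only the fallback (flash_attn) path drops these arguments; sgl_kernel honors them.
    if not is_amd:
        return
    dropped = [name for name in _FA3_ONLY_ARGS if kwargs.get(name) is not None]
    if dropped:
        raise ValueError(f"FA2 fallback cannot honor FA3-only arguments: {dropped}")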