60 changes: 55 additions & 5 deletions configs/mixed_landmark_0814_no_extend_qsa.json
@@ -7,24 +7,74 @@
"__delta_attention_args": "window_0-diff_1-w_16-dense_decode-smooth",
"using_extend": false,
"dense_layers": [0, 1, 2, 47, 46, 45],
"mask_refresh_interval": [96],
"mask_refresh_interval": [96, 32, 16],
"layers": [
{
"sliding_window_size": 1024,
"sliding_window_size_for_masking_step": [1024, 1024, 1024],
"second_stage_k": 1024,
"second_stage_k": 2048,
"sink_token_size": 1024,
"sa_extend_backend": "self_extend",
"stages": [ { } ]
"stages": [
{
"stage_block_size_q":128,
"stage_block_stride_q":4,
"stage_chunk_size":256,
"stage_k":null,
"stage_stride":1,
"using_landmark":false
},
{
"stage_block_size_q":64,
"stage_block_stride_q":1,
"stage_chunk_size":32,
"stage_k":65536,
"stage_stride":1,
"using_landmark":false
},
{
"stage_block_size_q":64,
"stage_block_stride_q":1,
"stage_chunk_size":8,
"stage_k":8192,
"stage_stride":1,
"using_landmark":false
}
]
},
{
"sliding_window_size": 1024,
"sliding_window_size_for_masking_step": [1024, 1024, 1024],
"second_stage_k": 1024,
"second_stage_k": 2048,
"sink_token_size": 1024,
"sa_extend_backend": "self_extend",
"scan_extend_backend": "none",
"stages": [ { } ]
"stages": [
{
"stage_block_size_q":128,
"stage_block_stride_q":4,
"stage_chunk_size":256,
"stage_k":null,
"stage_stride":1,
"using_landmark":false
},
{
"stage_block_size_q":64,
"stage_block_stride_q":1,
"stage_chunk_size":32,
"stage_k":65536,
"stage_stride":1,
"using_landmark":false
},
{
"stage_block_size_q":64,
"stage_block_stride_q":1,
"stage_chunk_size":8,
"stage_k":8192,
"stage_stride":1,
"using_landmark":false
}
]
}
],
Comment on lines 8 to 79 (@kbumsik, Contributor, Oct 16, 2025):

I suppose it is used for qwen3.

So does this PR also fix our qwen3 + nvidia? To me, this PR can't be ignored even if we don't need AMD.

"prefill_layers": [
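For context, the previously empty stage placeholders (`"stages": [ { } ]`) are now spelled out as a three-step coarse-to-fine hierarchy: the chunk size shrinks 256 → 32 → 8 while `stage_k` tightens from unlimited (`null`) to 65536 and then 8192, and `mask_refresh_interval` gains one entry per masking step. A minimal sanity-check sketch (the path comes from the diff header; field names and values are copied from the JSON above):

```python
import json

# Sanity-check the updated config (path taken from the diff header above).
with open("configs/mixed_landmark_0814_no_extend_qsa.json") as f:
    cfg = json.load(f)

assert cfg["mask_refresh_interval"] == [96, 32, 16]

for i, layer in enumerate(cfg["layers"]):
    assert layer["second_stage_k"] == 2048
    stages = layer["stages"]
    # Coarse-to-fine: smaller chunks and a tighter top-k budget at each step.
    assert [s["stage_chunk_size"] for s in stages] == [256, 32, 8]
    assert [s["stage_k"] for s in stages] == [None, 65536, 8192]
    print(f"layer entry {i}: {len(stages)} masking stages OK")
```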
148 changes: 148 additions & 0 deletions src/hip_attn/utils/sglang_watchdog.py
@@ -0,0 +1,148 @@
import argparse
import datetime
import sys
import os
import subprocess
import threading
import time
import traceback
import requests

def log(*args):
comment = " ".join([str(a) for a in args])
timestamp = "{:%Y-%m-%d %H:%M:%S}".format(datetime.datetime.now())
print(f"\033[91m[{timestamp} sglang_watchdog] {comment}\033[0m", flush=True)

class Watchdog:
def __init__(
self,
):
self.timeout_bootup = 600
self.timeout_tick = 60
self.sleep_step = 1
self.proc: subprocess.Popen = None
self.argv: list[str] = None
self.running: bool = True

def start_subprocess(self):
args = [
"python",
"-m",
"sglang.launch_server",
*self.argv
]
flatten_args = " ".join(args)
log(f"Start subprocess using following command: {flatten_args}")
self.proc = subprocess.Popen(args)
log(f"Start subprocess communication.")
return_code = self.proc.wait()
log(f"Return code is {return_code}")

def kill_subprocess(self):
log(f"Start kill subprocess")
if self.proc is not None:
self.proc.kill()
self.proc = None
subprocess.call(["pkill", "sglang"])
log(f"Finish kill subprocess")

def wait_for_health(self, timeout: int):
response = requests.get(self.health_endpoint, timeout=timeout)
response.raise_for_status()

def main_watchdog(self):
while True:
try:
t_boot = time.time()
booted = False
while self.proc is None:
log("Watchdog is waiting for process started...")
time.sleep(self.sleep_step)
while (
(time.time() - t_boot) < self.timeout_bootup
and self.proc.returncode is None
and not booted
):
try:
self.wait_for_health(timeout=self.timeout_bootup)
log("Server booted successfully.")
booted = True
except (TimeoutError, requests.HTTPError, requests.ConnectionError):
                        # NOTE: the process may not have started yet
pass
time.sleep(self.sleep_step)

if not booted: raise TimeoutError()

while True:
log("Try watch dog.")
self.wait_for_health(timeout=self.timeout_tick)
log("Done watch dog successfully.")
time.sleep(self.timeout_tick)

except (TimeoutError, requests.HTTPError):
self.kill_subprocess()
except Exception as ex:
trace = traceback.format_exc()
log(f"Traceback:\n{trace}")
log(f"Unexpected error on watchdog thread: {ex}")
self.kill_subprocess()

time.sleep(self.sleep_step)

def main_starter(self):
while True:
self.start_subprocess()
time.sleep(self.sleep_step)

def start(self):
try:
if "--" in sys.argv:
my_args = sys.argv[1:sys.argv.index("--")]
argv = sys.argv[sys.argv.index("--") + 1:]
else:
my_args = []
argv = sys.argv[1:]

parser = argparse.ArgumentParser()
parser.add_argument("--timeout-bootup", default=self.timeout_bootup, type=int)
parser.add_argument("--timeout", default=self.timeout_tick, type=int)
parser.add_argument("--sleep-step", default=self.sleep_step, type=int)

args = parser.parse_args(my_args)
self.timeout_bootup = args.timeout_bootup
self.timeout_tick = args.timeout
self.sleep_step = args.sleep_step

assert "--host" in argv
assert "--port" in argv
self.host = argv[argv.index("--host") + 1]
self.port = argv[argv.index("--port") + 1]
self.health_endpoint = f"http://{self.host}:{self.port}/health"
log(f"Watching: {self.health_endpoint}")

self.argv = argv

self.thread_watchdog = threading.Thread(
target=self.main_watchdog,
daemon=True
)
self.thread_starter = threading.Thread(
target=self.main_starter,
daemon=True
)

self.thread_starter.start()
time.sleep(self.sleep_step)
self.thread_watchdog.start()

self.thread_watchdog.join()
self.thread_starter.join()

self.running = False
except KeyboardInterrupt:
self.kill_subprocess()

if __name__ == '__main__':
dog = Watchdog()
dog.start()
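A usage note for the new watchdog: flags before the `--` separator configure the watchdog itself (they match the `ArgumentParser` above), and everything after `--` is forwarded verbatim to `python -m sglang.launch_server`; `--host` and `--port` must appear in the forwarded part because the `/health` endpoint is derived from them. A minimal sketch of that splitting, with an illustrative command line (the sglang flags are examples, not part of this PR):

```python
# Illustrative command line (sglang flags are examples, not from this PR):
#   python -m hip_attn.utils.sglang_watchdog --timeout 60 -- --host 0.0.0.0 --port 30000
argv = ["--timeout", "60", "--", "--host", "0.0.0.0", "--port", "30000"]

sep = argv.index("--")
my_args = argv[:sep]          # consumed by the watchdog's ArgumentParser
server_args = argv[sep + 1:]  # forwarded verbatim to `python -m sglang.launch_server`

host = server_args[server_args.index("--host") + 1]
port = server_args[server_args.index("--port") + 1]
print(my_args, server_args, f"http://{host}:{port}/health")
```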
7 changes: 6 additions & 1 deletion src/hip_attn/v1_2/attention_extend.py
@@ -16,7 +16,12 @@
from hip_attn.utils.rope import adjust_rope
from hip_attn.v1_2.attention_decode_bsa import decode_block_sparse_attention
from hip_attn.v1_2.attention_extend_bsa import block_sparse_attention
from hip_attn.v1_2.attention_extend_bsa_tilelang import block_sparse_attention_tilelang

try:
from hip_attn.v1_2.attention_extend_bsa_tilelang import block_sparse_attention_tilelang
except (ImportError, OSError):
block_sparse_attention_tilelang = None

from hip_attn.v1_2.attention_metadata import (
EnsembleScoreStage,
EvalScoreStage,
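Since `block_sparse_attention_tilelang` now falls back to `None` when the TileLang kernel cannot be imported (or its shared library fails to load with an `OSError`), call sites need to check for that before dispatching. A minimal sketch of such a guard; the `use_tilelang` flag and the argument passthrough are hypothetical, not taken from this diff:

```python
# Hypothetical call site; only the two imported kernel names come from the diff above.
def run_block_sparse_attention(use_tilelang: bool, *args, **kwargs):
    if use_tilelang and block_sparse_attention_tilelang is not None:
        return block_sparse_attention_tilelang(*args, **kwargs)
    # Fall back to the Triton-based kernel, which is always importable.
    return block_sparse_attention(*args, **kwargs)
```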
76 changes: 72 additions & 4 deletions src/hip_attn/v1_2/paged_hip.py
@@ -2,17 +2,85 @@
import math
import os
import warnings
from typing import Any, List, Optional
from typing import Any, List, Optional, Union

import cv2
import numba
import numpy as np
import torch
import triton
from flash_attn import flash_attn_func
from matplotlib import pyplot as plt
from sgl_kernel.flash_attn import flash_attn_varlen_func as __flash_attn_varlen_func
from sgl_kernel.flash_attn import flash_attn_with_kvcache

try:
from flash_attn import flash_attn_func
except ImportError:
flash_attn_func = None

try:
from sgl_kernel.flash_attn import flash_attn_varlen_func as __flash_attn_varlen_func
from sgl_kernel.flash_attn import flash_attn_with_kvcache
IS_AMD = False
except ImportError:
# FIXME: better AMD detection algorithm
IS_AMD = True

from flash_attn import flash_attn_varlen_func as __flash_attn_varlen_func
from flash_attn import flash_attn_with_kvcache as __flash_attn_with_kvcache

def flash_attn_with_kvcache(
q,
k_cache,
v_cache,
k=None,
v=None,
qv=None,
rotary_cos=None,
rotary_sin=None,
cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
cache_batch_idx: Optional[torch.Tensor] = None,
cache_leftpad: Optional[torch.Tensor] = None,
page_table: Optional[torch.Tensor] = None,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_k_new: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
rotary_seqlens: Optional[torch.Tensor] = None,
q_descale: Optional[torch.Tensor] = None,
k_descale: Optional[torch.Tensor] = None,
v_descale: Optional[torch.Tensor] = None,
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
softcap=0.0, # 0.0 means deactivated
rotary_interleaved=True,
scheduler_metadata=None,
num_splits=0, # Can be tuned for speed
pack_gqa=None, # Can be tuned for speed
sm_margin=0, # Can be tuned if some SMs are used for communication
return_softmax_lse=False,
sinks=None,
ver=3,
):
return __flash_attn_with_kvcache(
q,
k_cache,
v_cache,
k=k,
v=v,
rotary_cos=rotary_cos,
rotary_sin=rotary_sin,
cache_seqlens=cache_seqlens,
cache_batch_idx=cache_batch_idx,
cache_leftpad=cache_leftpad,
block_table=page_table,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size, # -1 means infinite context window
softcap=softcap, # 0.0 means deactivated
rotary_interleaved=rotary_interleaved,
alibi_slopes=None,
num_splits=num_splits,
return_softmax_lse=return_softmax_lse,
)

from hip_attn.v1_2.hip_config import HiPAttentionConfig
from hip_attn.v1_2.utils import capture
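In short, when `sgl_kernel`'s flash-attn entry points are unavailable (treated here as the AMD case), the shim above rebinds `flash_attn_with_kvcache` to the flash-attn 2 implementation: `page_table` is translated to FA2's `block_table` keyword, while FA3-only arguments such as `qv`, `cu_seqlens_q`, and the descale tensors are accepted but dropped. Call sites can therefore keep a single signature on both paths. An illustrative call (tensor names and shapes are placeholders, not from this PR):

```python
# Placeholder tensors; the point is that the same keyword arguments work on both
# the sgl_kernel (FA3) path and the AMD fallback defined above.
out, lse = flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    cache_seqlens=cache_seqlens,
    page_table=page_table,   # mapped to block_table on the fallback path
    causal=True,
    return_softmax_lse=True,
)
```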
2 changes: 1 addition & 1 deletion src/hip_attn/v1_2/query_sparse_attention.py
@@ -1907,7 +1907,7 @@ def forward(
assert rope_cos.ndim == 2
assert extend_backend in ["self_extend", "nope"]

if rope_sin is not None:
if (rope_sin is not None) and (extend_backend in ["self_extend"]):
HEAD_DIM_K_ROPE = rope_sin.shape[-1]
HEAD_DIM_K_NOPE = HEAD_DIM_K - HEAD_DIM_K_ROPE
else: