
Commit e3ce2be

Merge pull request #5 from adjust/saber/redis-stream
Feat(Stream): Use redis stream
2 parents: c2f1b77 + 81f9565

File tree

6 files changed: +311, -36 lines

arq/connections.py

Lines changed: 32 additions & 2 deletions

@@ -14,8 +14,10 @@
 from redis.asyncio.sentinel import Sentinel
 from redis.exceptions import RedisError, WatchError
 
-from .constants import default_queue_name, expires_extra_ms, job_key_prefix, result_key_prefix
+from .constants import default_queue_name, expires_extra_ms, job_key_prefix, result_key_prefix, stream_key_suffix, \
+    job_message_id_prefix
 from .jobs import Deserializer, Job, JobDef, JobResult, Serializer, deserialize_job, serialize_job
+from .lua_script import publish_job_lua
 from .utils import timestamp_ms, to_ms, to_unix_ms
 
 logger = logging.getLogger('arq.connections')
@@ -117,6 +119,7 @@ def __init__(
         if pool_or_conn:
             kwargs['connection_pool'] = pool_or_conn
         self.expires_extra_ms = expires_extra_ms
+        self.publish_job_sha = None
         super().__init__(**kwargs)
 
     async def enqueue_job(
@@ -129,6 +132,7 @@ async def enqueue_job(
         _defer_by: Union[None, int, float, timedelta] = None,
         _expires: Union[None, int, float, timedelta] = None,
         _job_try: Optional[int] = None,
+        _use_stream: bool = False,
         distribution: Optional[str] = None,  # example 5:2
         **kwargs: Any,
     ) -> Optional[Job]:
@@ -167,6 +171,9 @@ async def enqueue_job(
         defer_by_ms = to_ms(_defer_by)
         expires_ms = to_ms(_expires)
 
+        if _use_stream is True and self.publish_job_sha is None:
+            self.publish_job_sha = await self.script_load(publish_job_lua)
+
         async with self.pipeline(transaction=True) as pipe:
             await pipe.watch(job_key)
             if await pipe.exists(job_key, result_key_prefix + job_id):
@@ -186,14 +193,37 @@ async def enqueue_job(
             job = serialize_job(function, args, kwargs, _job_try, enqueue_time_ms, serializer=self.job_serializer)
             pipe.multi()
             pipe.psetex(job_key, expires_ms, job)
-            pipe.zadd(_queue_name, {job_id: score})
+
+            if _use_stream is False:
+                pipe.zadd(_queue_name, {job_id: score})
+            else:
+                stream_key = _queue_name + stream_key_suffix
+                job_message_id_key = job_message_id_prefix + job_id
+
+                pipe.evalsha(
+                    self.publish_job_sha,
+                    2,
+                    # keys
+                    stream_key,
+                    job_message_id_key,
+                    # args
+                    job_id,
+                    str(enqueue_time_ms),
+                    str(expires_ms),
+                )
             try:
                 await pipe.execute()
             except WatchError:
                 # job got enqueued since we checked 'job_exists'
                 return None
         return Job(job_id, redis=self, _queue_name=_queue_name, _deserializer=self.job_deserializer)
 
+    async def get_stream_size(self, queue_name: str | None = None, include_delayed_tasks: bool = True) -> int:
+        if queue_name is None:
+            queue_name = self.default_queue_name
+
+        return await self.xlen(queue_name + stream_key_suffix)
+
     def _get_queue_index(self, distribution: Optional[str]) -> int:
         ratios = list(map(lambda x: int(x), distribution.split(':')))  # type: ignore[union-attr]
         ratios_sum = sum(ratios)
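
As a rough usage sketch (the download_content function name is made up; create_pool and RedisSettings are arq's existing API), the stream path is opt-in per enqueue call:

import asyncio

from arq.connections import RedisSettings, create_pool


async def main() -> None:
    pool = await create_pool(RedisSettings())

    # The first stream enqueue loads publish_job_lua and caches its SHA on
    # the pool; later calls reuse it via EVALSHA inside the pipeline.
    job = await pool.enqueue_job('download_content', 'https://example.com', _use_stream=True)
    print(job.job_id if job else 'job already exists')

    # XLEN of '<queue>:stream' counts entries in the stream, including
    # messages no consumer has claimed yet.
    print(await pool.get_stream_size())


asyncio.run(main())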

arq/constants.py

Lines changed: 3 additions & 0 deletions

@@ -1,9 +1,12 @@
 default_queue_name = 'arq:queue'
 job_key_prefix = 'arq:job:'
 in_progress_key_prefix = 'arq:in-progress:'
+job_message_id_prefix = 'arq:message-id:'
 result_key_prefix = 'arq:result:'
 retry_key_prefix = 'arq:retry:'
 abort_jobs_ss = 'arq:abort'
+stream_key_suffix = ':stream'
+default_consumer_group = 'arq:consumers'
 # age of items in the abort_key sorted set after which they're deleted
 abort_job_max_age = 60
 health_check_key_suffix = ':health-check'
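
To make the new naming concrete, a hypothetical job id 'abc123' on the default queue maps to these keys (an illustration, not part of the diff):

from arq.constants import default_queue_name, job_message_id_prefix, stream_key_suffix

job_id = 'abc123'                                    # hypothetical job id
stream_key = default_queue_name + stream_key_suffix  # 'arq:queue:stream'
message_id_key = job_message_id_prefix + job_id      # 'arq:message-id:abc123'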

arq/jobs.py

Lines changed: 28 additions & 7 deletions

@@ -9,8 +9,10 @@
 
 from redis.asyncio import Redis
 
-from .constants import abort_jobs_ss, default_queue_name, in_progress_key_prefix, job_key_prefix, result_key_prefix
-from .utils import ms_to_datetime, poll, timestamp_ms
+from .constants import abort_jobs_ss, default_queue_name, in_progress_key_prefix, job_key_prefix, result_key_prefix, \
+    stream_key_suffix, job_message_id_prefix
+from .lua_script import get_job_from_stream_lua
+from .utils import ms_to_datetime, poll, timestamp_ms, _list_to_dict
 
 logger = logging.getLogger('arq.jobs')
 
@@ -105,7 +107,8 @@ async def result(
         async with self._redis.pipeline(transaction=True) as tr:
             tr.get(result_key_prefix + self.job_id)
             tr.zscore(self._queue_name, self.job_id)
-            v, s = await tr.execute()
+            tr.get(job_message_id_prefix + self.job_id)
+            v, s, m = await tr.execute()
 
         if v:
             info = deserialize_result(v, deserializer=self._deserializer)
@@ -115,7 +118,7 @@ async def result(
                 raise info.result
             else:
                 raise SerializationError(info.result)
-        elif s is None:
+        elif s is None and m is None:
             raise ResultNotFound(
                 'Not waiting for job result because the job is not in queue. '
                 'Is the worker function configured to keep result?'
@@ -134,8 +137,23 @@ async def info(self) -> Optional[JobDef]:
             if v:
                 info = deserialize_job(v, deserializer=self._deserializer)
         if info:
-            s = await self._redis.zscore(self._queue_name, self.job_id)
-            info.score = None if s is None else int(s)
+            async with self._redis.pipeline(transaction=True) as tr:
+                tr.zscore(self._queue_name, self.job_id)
+                tr.eval(
+                    get_job_from_stream_lua,
+                    2,
+                    self._queue_name + stream_key_suffix,
+                    job_message_id_prefix + self.job_id,
+                )
+                delayed_score, job_info = await tr.execute()
+
+            if delayed_score:
+                info.score = int(delayed_score)
+            elif job_info:
+                _, job_info_payload = job_info
+                info.score = int(_list_to_dict(job_info_payload)[b'score'])
+            else:
+                info.score = None
         return info
 
     async def result_info(self) -> Optional[JobResult]:
@@ -157,12 +175,15 @@ async def status(self) -> JobStatus:
             tr.exists(result_key_prefix + self.job_id)
             tr.exists(in_progress_key_prefix + self.job_id)
             tr.zscore(self._queue_name, self.job_id)
-            is_complete, is_in_progress, score = await tr.execute()
+            tr.exists(job_message_id_prefix + self.job_id)
+            is_complete, is_in_progress, score, queued = await tr.execute()
 
         if is_complete:
             return JobStatus.complete
         elif is_in_progress:
             return JobStatus.in_progress
+        elif queued:
+            return JobStatus.queued
         elif score:
             return JobStatus.deferred if score > timestamp_ms() else JobStatus.queued
         else:
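
A hedged sketch of how the new precedence plays out for callers (wait_for_start is a made-up helper, not part of this diff): a stream-enqueued job keeps an arq:message-id: key, so status() now answers queued from the new EXISTS check, while delayed jobs still fall through to the zset score comparison.

import asyncio

from arq.jobs import Job, JobStatus


async def wait_for_start(job: Job, interval: float = 0.1) -> None:
    # Stream-enqueued jobs resolve via the new message-id EXISTS branch;
    # deferred jobs still resolve from the zset score.
    while await job.status() in (JobStatus.queued, JobStatus.deferred):
        await asyncio.sleep(interval)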

arq/lua_script.py

Lines changed: 24 additions & 0 deletions

@@ -0,0 +1,24 @@
+publish_job_lua = """
+local stream_key = KEYS[1]
+local job_message_id_key = KEYS[2]
+local job_id = ARGV[1]
+local score = ARGV[2]
+local job_message_id_expire_ms = ARGV[3]
+local message_id = redis.call('xadd', stream_key, '*', 'job_id', job_id, 'score', score)
+redis.call('set', job_message_id_key, message_id, 'px', job_message_id_expire_ms)
+return message_id
+"""
+
+get_job_from_stream_lua = """
+local stream_key = KEYS[1]
+local job_message_id_key = KEYS[2]
+local message_id = redis.call('get', job_message_id_key)
+if message_id == false then
+    return nil
+end
+local job = redis.call('xrange', stream_key, message_id, message_id)
+if job == nil then
+    return nil
+end
+return job[1]
+"""

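Run directly with redis-py's eval (the key names and argument values below are invented for illustration), the two scripts round-trip like this:

import asyncio

from redis.asyncio import Redis

from arq.lua_script import get_job_from_stream_lua, publish_job_lua


async def main() -> None:
    redis = Redis()

    # XADD a {job_id, score} entry, then record its stream message id
    # under the message-id key with a millisecond expiry.
    message_id = await redis.eval(
        publish_job_lua, 2, 'arq:queue:stream', 'arq:message-id:abc123',
        'abc123', '1700000000000', '60000',
    )

    # Look up the message id and XRANGE that single entry; the reply is
    # [message_id, [field1, value1, ...]], or nil once the key has expired.
    entry = await redis.eval(get_job_from_stream_lua, 2, 'arq:queue:stream', 'arq:message-id:abc123')
    print(message_id, entry)


asyncio.run(main())
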
arq/utils.py

Lines changed: 3 additions & 0 deletions

@@ -148,3 +148,6 @@ def import_string(dotted_path: str) -> Any:
         return getattr(module, class_name)
     except AttributeError as e:
         raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute') from e
+
+def _list_to_dict(input_list: list[Any]) -> dict[Any, Any]:
+    return dict(zip(input_list[::2], input_list[1::2], strict=True))
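
As context for the XRANGE reply handling in jobs.py, stream entries carry their fields as a flat alternating list, which this helper pairs into a dict (the values below are illustrative):

from arq.utils import _list_to_dict

payload = [b'job_id', b'abc123', b'score', b'1700000000000']
# zip(..., strict=True) raises ValueError if the list has an odd length
assert _list_to_dict(payload) == {b'job_id': b'abc123', b'score': b'1700000000000'}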
