Skip to content

Commit b295f9a

Browse files
authored
FEAT: support getting progress for image model (#2395)
1 parent 20774af commit b295f9a

File tree

12 files changed

+575
-79
lines changed

12 files changed

+575
-79
lines changed

xinference/api/restful_api.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,16 @@ async def internal_exception_handler(request: Request, exc: Exception):
524524
else None
525525
),
526526
)
527+
self._router.add_api_route(
528+
"/v1/requests/{request_id}/progress",
529+
self.get_progress,
530+
methods=["get"],
531+
dependencies=(
532+
[Security(self._auth_service, scopes=["models:read"])]
533+
if self.is_authenticated()
534+
else None
535+
),
536+
)
527537
self._router.add_api_route(
528538
"/v1/images/generations",
529539
self.create_images,
@@ -1486,6 +1496,17 @@ async def create_speech(
14861496
await self._report_error_event(model_uid, str(e))
14871497
raise HTTPException(status_code=500, detail=str(e))
14881498

1499+
async def get_progress(self, request_id: str) -> JSONResponse:
    """Report generation progress for a tracked request.

    Asks the supervisor for the progress value recorded under
    *request_id* and wraps it as ``{"progress": <value>}``.

    Raises:
        HTTPException: 400 when the supervisor has no entry for the
            request id (``KeyError``); 500 for any other failure.
    """
    try:
        supervisor = await self._get_supervisor_ref()
        payload = {"progress": await supervisor.get_progress(request_id)}
        return JSONResponse(content=payload)
    except KeyError as e:
        # Unknown request id — treat as a client error.
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
1509+
14891510
async def create_images(self, request: Request) -> Response:
14901511
body = TextToImageRequest.parse_obj(await request.json())
14911512
model_uid = body.model

xinference/client/restful/restful_client.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1385,6 +1385,16 @@ def get_supervisor_info(self):
13851385
response_json = response.json()
13861386
return response_json
13871387

1388+
def get_progress(self, request_id: str):
    """Fetch the progress of an in-flight request from the server.

    Args:
        request_id: Identifier previously supplied to the generation call.

    Returns:
        The decoded JSON body, e.g. ``{"progress": 0.42}``.

    Raises:
        RuntimeError: When the server responds with a non-200 status.
    """
    endpoint = f"{self.base_url}/v1/requests/{request_id}/progress"
    resp = requests.get(endpoint, headers=self._headers)
    if resp.status_code != 200:
        raise RuntimeError(
            f"Failed to get progress, detail: {_get_error_string(resp)}"
        )
    return resp.json()
1397+
13881398
def abort_cluster(self):
13891399
url = f"{self.base_url}/v1/clusters"
13901400
response = requests.delete(url, headers=self._headers)

xinference/core/image_interface.py

Lines changed: 76 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
import io
1717
import logging
1818
import os
19+
import threading
20+
import time
21+
import uuid
1922
from typing import Dict, List, Optional, Union
2023

2124
import gradio as gr
@@ -84,6 +87,7 @@ def text_generate_image(
8487
num_inference_steps: int,
8588
negative_prompt: Optional[str] = None,
8689
sampler_name: Optional[str] = None,
90+
progress=gr.Progress(),
8791
) -> PIL.Image.Image:
8892
from ..client import RESTfulClient
8993

@@ -99,19 +103,43 @@ def text_generate_image(
99103
)
100104
sampler_name = None if sampler_name == "default" else sampler_name
101105

102-
response = model.text_to_image(
103-
prompt=prompt,
104-
n=n,
105-
size=size,
106-
num_inference_steps=num_inference_steps,
107-
guidance_scale=guidance_scale,
108-
negative_prompt=negative_prompt,
109-
sampler_name=sampler_name,
110-
response_format="b64_json",
111-
)
106+
response = None
107+
exc = None
108+
request_id = str(uuid.uuid4())
109+
110+
def run_in_thread():
111+
nonlocal exc, response
112+
try:
113+
response = model.text_to_image(
114+
request_id=request_id,
115+
prompt=prompt,
116+
n=n,
117+
size=size,
118+
num_inference_steps=num_inference_steps,
119+
guidance_scale=guidance_scale,
120+
negative_prompt=negative_prompt,
121+
sampler_name=sampler_name,
122+
response_format="b64_json",
123+
)
124+
except Exception as e:
125+
exc = e
126+
127+
t = threading.Thread(target=run_in_thread)
128+
t.start()
129+
while t.is_alive():
130+
try:
131+
cur_progress = client.get_progress(request_id)["progress"]
132+
except (KeyError, RuntimeError):
133+
cur_progress = 0.0
134+
135+
progress(cur_progress, desc="Generating images")
136+
time.sleep(1)
137+
138+
if exc:
139+
raise exc
112140

113141
images = []
114-
for image_dict in response["data"]:
142+
for image_dict in response["data"]: # type: ignore
115143
assert image_dict["b64_json"] is not None
116144
image_data = base64.b64decode(image_dict["b64_json"])
117145
image = PIL.Image.open(io.BytesIO(image_data))
@@ -184,6 +212,7 @@ def image_generate_image(
184212
num_inference_steps: int,
185213
padding_image_to_multiple: int,
186214
sampler_name: Optional[str] = None,
215+
progress=gr.Progress(),
187216
) -> PIL.Image.Image:
188217
from ..client import RESTfulClient
189218

@@ -205,20 +234,44 @@ def image_generate_image(
205234
bio = io.BytesIO()
206235
image.save(bio, format="png")
207236

208-
response = model.image_to_image(
209-
prompt=prompt,
210-
negative_prompt=negative_prompt,
211-
n=n,
212-
image=bio.getvalue(),
213-
size=size,
214-
response_format="b64_json",
215-
num_inference_steps=num_inference_steps,
216-
padding_image_to_multiple=padding_image_to_multiple,
217-
sampler_name=sampler_name,
218-
)
237+
response = None
238+
exc = None
239+
request_id = str(uuid.uuid4())
240+
241+
def run_in_thread():
242+
nonlocal exc, response
243+
try:
244+
response = model.image_to_image(
245+
request_id=request_id,
246+
prompt=prompt,
247+
negative_prompt=negative_prompt,
248+
n=n,
249+
image=bio.getvalue(),
250+
size=size,
251+
response_format="b64_json",
252+
num_inference_steps=num_inference_steps,
253+
padding_image_to_multiple=padding_image_to_multiple,
254+
sampler_name=sampler_name,
255+
)
256+
except Exception as e:
257+
exc = e
258+
259+
t = threading.Thread(target=run_in_thread)
260+
t.start()
261+
while t.is_alive():
262+
try:
263+
cur_progress = client.get_progress(request_id)["progress"]
264+
except (KeyError, RuntimeError):
265+
cur_progress = 0.0
266+
267+
progress(cur_progress, desc="Generating images")
268+
time.sleep(1)
269+
270+
if exc:
271+
raise exc
219272

220273
images = []
221-
for image_dict in response["data"]:
274+
for image_dict in response["data"]: # type: ignore
222275
assert image_dict["b64_json"] is not None
223276
image_data = base64.b64decode(image_dict["b64_json"])
224277
image = PIL.Image.open(io.BytesIO(image_data))

xinference/core/model.py

Lines changed: 80 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from ..constants import XINFERENCE_TRANSFORMERS_ENABLE_BATCHING
4545

4646
if TYPE_CHECKING:
47+
from .progress_tracker import ProgressTrackerActor
4748
from .worker import WorkerActor
4849
from ..model.llm.core import LLM
4950
from ..model.core import ModelDescription
@@ -177,6 +178,7 @@ async def __pre_destroy__(self):
177178

178179
def __init__(
179180
self,
181+
supervisor_address: str,
180182
worker_address: str,
181183
model: "LLM",
182184
model_description: Optional["ModelDescription"] = None,
@@ -188,6 +190,7 @@ def __init__(
188190
from ..model.llm.transformers.core import PytorchModel
189191
from ..model.llm.vllm.core import VLLMModel
190192

193+
self._supervisor_address = supervisor_address
191194
self._worker_address = worker_address
192195
self._model = model
193196
self._model_description = (
@@ -205,6 +208,7 @@ def __init__(
205208
else asyncio.locks.Lock()
206209
)
207210
self._worker_ref = None
211+
self._progress_tracker_ref = None
208212
self._serve_count = 0
209213
self._metrics_labels = {
210214
"type": self._model_description.get("model_type", "unknown"),
@@ -275,6 +279,28 @@ async def _get_worker_ref(self) -> xo.ActorRefType["WorkerActor"]:
275279
)
276280
return self._worker_ref
277281

282+
async def _get_progress_tracker_ref(
    self,
) -> xo.ActorRefType["ProgressTrackerActor"]:
    """Resolve the supervisor-side progress tracker actor, caching the ref.

    The actor lives at the supervisor address; the lookup is performed at
    most once and memoized on ``self._progress_tracker_ref``.
    """
    # Local import to avoid a module-level import cycle.
    from .progress_tracker import ProgressTrackerActor

    cached = self._progress_tracker_ref
    if cached is not None:
        return cached
    ref = await xo.actor_ref(
        address=self._supervisor_address,
        uid=ProgressTrackerActor.default_uid(),
    )
    self._progress_tracker_ref = ref
    return ref
292+
293+
async def _get_progressor(self, request_id: str):
    """Build and start a ``Progressor`` bound to *request_id*.

    The progressor reports progress updates for this request to the
    supervisor's progress tracker actor via the current event loop.
    """
    # Local import to avoid a module-level import cycle.
    from .progress_tracker import Progressor

    tracker_ref = await self._get_progress_tracker_ref()
    loop = asyncio.get_running_loop()
    progressor = Progressor(request_id, tracker_ref, loop)
    await progressor.start()
    return progressor
303+
278304
def is_vllm_backend(self) -> bool:
279305
from ..model.llm.vllm.core import VLLMModel
280306

@@ -732,17 +758,20 @@ async def text_to_image(
732758
*args,
733759
**kwargs,
734760
):
735-
kwargs.pop("request_id", None)
736761
if hasattr(self._model, "text_to_image"):
737-
return await self._call_wrapper_json(
738-
self._model.text_to_image,
739-
prompt,
740-
n,
741-
size,
742-
response_format,
743-
*args,
744-
**kwargs,
762+
progressor = kwargs["progressor"] = await self._get_progressor(
763+
kwargs.pop("request_id", None)
745764
)
765+
with progressor:
766+
return await self._call_wrapper_json(
767+
self._model.text_to_image,
768+
prompt,
769+
n,
770+
size,
771+
response_format,
772+
*args,
773+
**kwargs,
774+
)
746775
raise AttributeError(
747776
f"Model {self._model.model_spec} is not for creating image."
748777
)
@@ -753,12 +782,15 @@ async def txt2img(
753782
self,
754783
**kwargs,
755784
):
756-
kwargs.pop("request_id", None)
757785
if hasattr(self._model, "txt2img"):
758-
return await self._call_wrapper_json(
759-
self._model.txt2img,
760-
**kwargs,
786+
progressor = kwargs["progressor"] = await self._get_progressor(
787+
kwargs.pop("request_id", None)
761788
)
789+
with progressor:
790+
return await self._call_wrapper_json(
791+
self._model.txt2img,
792+
**kwargs,
793+
)
762794
raise AttributeError(f"Model {self._model.model_spec} is not for txt2img.")
763795

764796
@log_async(
@@ -776,19 +808,22 @@ async def image_to_image(
776808
*args,
777809
**kwargs,
778810
):
779-
kwargs.pop("request_id", None)
780811
kwargs["negative_prompt"] = negative_prompt
781812
if hasattr(self._model, "image_to_image"):
782-
return await self._call_wrapper_json(
783-
self._model.image_to_image,
784-
image,
785-
prompt,
786-
n,
787-
size,
788-
response_format,
789-
*args,
790-
**kwargs,
813+
progressor = kwargs["progressor"] = await self._get_progressor(
814+
kwargs.pop("request_id", None)
791815
)
816+
with progressor:
817+
return await self._call_wrapper_json(
818+
self._model.image_to_image,
819+
image,
820+
prompt,
821+
n,
822+
size,
823+
response_format,
824+
*args,
825+
**kwargs,
826+
)
792827
raise AttributeError(
793828
f"Model {self._model.model_spec} is not for creating image."
794829
)
@@ -799,12 +834,15 @@ async def img2img(
799834
self,
800835
**kwargs,
801836
):
802-
kwargs.pop("request_id", None)
803837
if hasattr(self._model, "img2img"):
804-
return await self._call_wrapper_json(
805-
self._model.img2img,
806-
**kwargs,
838+
progressor = kwargs["progressor"] = await self._get_progressor(
839+
kwargs.pop("request_id", None)
807840
)
841+
with progressor:
842+
return await self._call_wrapper_json(
843+
self._model.img2img,
844+
**kwargs,
845+
)
808846
raise AttributeError(f"Model {self._model.model_spec} is not for img2img.")
809847

810848
@log_async(
@@ -823,20 +861,23 @@ async def inpainting(
823861
*args,
824862
**kwargs,
825863
):
826-
kwargs.pop("request_id", None)
864+
kwargs["negative_prompt"] = negative_prompt
827865
if hasattr(self._model, "inpainting"):
828-
return await self._call_wrapper_json(
829-
self._model.inpainting,
830-
image,
831-
mask_image,
832-
prompt,
833-
negative_prompt,
834-
n,
835-
size,
836-
response_format,
837-
*args,
838-
**kwargs,
866+
progressor = kwargs["progressor"] = await self._get_progressor(
867+
kwargs.pop("request_id", None)
839868
)
869+
with progressor:
870+
return await self._call_wrapper_json(
871+
self._model.inpainting,
872+
image,
873+
mask_image,
874+
prompt,
875+
n,
876+
size,
877+
response_format,
878+
*args,
879+
**kwargs,
880+
)
840881
raise AttributeError(
841882
f"Model {self._model.model_spec} is not for creating image."
842883
)

0 commit comments

Comments
 (0)