|
3 | 3 | import habana_frameworks.torch.core as htcore |
4 | 4 |
|
5 | 5 | from loguru import logger |
6 | | -from typing import Dict, Union |
| 6 | +from typing import Dict |
7 | 7 | from text_generation_server.pb.generate_pb2 import GrammarType |
8 | 8 |
|
9 | 9 | from outlines.fsm.fsm import RegexFSM |
|
13 | 13 | import time |
14 | 14 |
|
15 | 15 | from transformers import ( |
16 | | - LogitsWarper, |
17 | 16 | LogitsProcessor, |
18 | 17 | TemperatureLogitsWarper, |
19 | 18 | TopKLogitsWarper, |
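
Context for this hunk: `LogitsWarper` was deprecated in recent `transformers` releases, and `LogitsProcessor` uses the same `__call__(input_ids, scores)` contract, so it serves as the drop-in base class throughout this diff. A minimal sketch of the pattern the rest of the change adopts; the class name is made up for illustration:

```python
import torch
from transformers import LogitsProcessor

class ScaledLogitsProcessor(LogitsProcessor):
    """Illustrative only: divides every logit by a constant factor."""

    def __init__(self, factor: float):
        self.factor = factor

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        # Same (input_ids, scores) -> scores contract that LogitsWarper used.
        return scores / self.factor
```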
@@ -191,7 +190,7 @@ def filter(self, indices): |
191 | 190 |
|
192 | 191 | class HeterogeneousTemperatureLogitsWarper: |
193 | 192 | r""" |
194 | | - [`LogitsWarper`] for temperature (exponential scaling output probability distribution). |
| 193 | + [`LogitsProcessor`] for temperature (exponential scaling output probability distribution). |
195 | 194 | This version allows for a separate value for each sample and runs inplace when possible. |
196 | 195 | It doesn't validate inputs. |
197 | 196 |
|
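A hedged sketch of the per-sample behavior this docstring describes, assuming `scores` is `[batch, vocab]` and `temperatures` holds one value per row; the function name is hypothetical:

```python
import torch

def apply_heterogeneous_temperature(
    scores: torch.Tensor, temperatures: torch.Tensor
) -> torch.Tensor:
    # One temperature per sample; the in-place division broadcasts
    # [batch, 1] across the vocabulary dimension.
    scores.div_(temperatures.unsqueeze(1))
    return scores
```

Dividing logits by a temperature above 1 flattens the distribution; a temperature below 1 sharpens it.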
@@ -220,7 +219,7 @@ def filter(self, indices): |
220 | 219 | return None |
221 | 220 |
|
222 | 221 |
|
223 | | -class HeterogeneousTopPLogitsWarper(LogitsWarper): |
| 222 | +class HeterogeneousTopPLogitsWarper(LogitsProcessor): |
224 | 223 | """ |
225 | | - [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
| 224 | + [`LogitsProcessor`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off.
226 | 225 | This version allows for a separate value for each sample and runs inplace when possible. |
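For reference, a single-value sketch of the top-p (nucleus) filtering this class generalizes to one value per sample: keep the smallest set of highest-probability tokens whose cumulative probability reaches `top_p`, and mask the rest. Names are illustrative, not the file's implementation:

```python
import torch

def top_p_filter(scores: torch.Tensor, top_p: float) -> torch.Tensor:
    # Sort descending and accumulate probability mass.
    sorted_logits, sorted_indices = torch.sort(scores, descending=True, dim=-1)
    cumulative = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Drop tokens once the mass exceeds top_p; the shift keeps the
    # first token that crosses the threshold.
    drop = cumulative > top_p
    drop[..., 1:] = drop[..., :-1].clone()
    drop[..., 0] = False
    # Map the sorted mask back to vocabulary order.
    drop = drop.scatter(-1, sorted_indices, drop)
    return scores.masked_fill(drop, -float("inf"))
```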
@@ -279,9 +278,9 @@ def filter(self, indices): |
279 | 278 | return None |
280 | 279 |
|
281 | 280 |
|
282 | | -class HeterogeneousTopKLogitsWarper(LogitsWarper): |
| 281 | +class HeterogeneousTopKLogitsWarper(LogitsProcessor): |
283 | 282 | r""" |
284 | | - [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. |
| 283 | + [`LogitsProcessor`] that performs top-k, i.e. restricting to the k highest probability elements. |
285 | 284 | This version allows for a separate value for each sample and runs inplace when possible. |
286 | 285 | It doesn't validate inputs. |
287 | 286 |
|
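Top-k, by contrast, is a fixed-size cutoff. A single-value sketch of what the class above does with one `k` per sample; the helper name is made up:

```python
import torch

def top_k_filter(scores: torch.Tensor, top_k: int) -> torch.Tensor:
    # Mask everything below the k-th largest logit in each row;
    # ties with the k-th value survive the cutoff.
    kth = torch.topk(scores, top_k, dim=-1).values[..., -1, None]
    return scores.masked_fill(scores < kth, -float("inf"))
```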
@@ -360,9 +359,9 @@ def filter(self, indices): |
360 | 359 | return None |
361 | 360 |
|
362 | 361 |
|
363 | | -class HeterogeneousTypicalLogitsWarper(LogitsWarper): |
| 362 | +class HeterogeneousTypicalLogitsWarper(LogitsProcessor): |
364 | 363 | r""" |
365 | | - [`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language |
| 364 | + [`LogitsProcessor`] that performs typical decoding. See [Typical Decoding for Natural Language |
366 | 365 | Generation](https://arxiv.org/abs/2202.00666) for more information. |
367 | 366 | This version allows for a separate value for each sample and runs inplace when possible. |
368 | 367 | It doesn't validate inputs. |
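A single-value sketch of the typical-decoding filter from the paper linked above: rank tokens by how far their surprisal is from the distribution's entropy, then keep the most "typical" tokens until `mass` probability is covered. This is an assumption-laden illustration, not the heterogeneous implementation in this file:

```python
import torch

def typical_filter(scores: torch.Tensor, mass: float = 0.9) -> torch.Tensor:
    # Surprisal of every token and entropy of the whole distribution.
    log_probs = torch.log_softmax(scores, dim=-1)
    entropy = -(log_probs.exp() * log_probs).sum(dim=-1, keepdim=True)
    # Rank tokens by |surprisal - entropy|, most typical first.
    deviation = ((-log_probs) - entropy).abs()
    sorted_dev, sorted_idx = torch.sort(deviation, dim=-1)
    sorted_probs = scores.gather(-1, sorted_idx).softmax(dim=-1)
    cumulative = sorted_probs.cumsum(dim=-1)
    # Last sorted position needed to cover `mass` probability.
    last = (cumulative < mass).sum(dim=-1, keepdim=True)
    last = last.clamp(max=scores.shape[-1] - 1)
    drop_sorted = sorted_dev > sorted_dev.gather(-1, last)
    drop = drop_sorted.scatter(-1, sorted_idx, drop_sorted)
    return scores.masked_fill(drop, -float("inf"))
```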
@@ -454,13 +453,13 @@ class HeterogeneousProcessorWrapper(LogitsProcessor): |
454 | 453 | r""" |
455 | 454 | A wrapper for logit warpers or processors without heterogeneous parameter support. |
456 | 455 | Args: |
457 | | - processors (`Dict[int, Union[LogitsProcessor, LogitsWarper]]`): |
| 456 | + processors (`Dict[int, LogitsProcessor]`): |
458 | 457 | A mapping of sample indices to logit warpers or processors, to be run sequentially. |
459 | 458 | """ |
460 | 459 |
|
461 | 460 | def __init__( |
462 | 461 | self, |
463 | | - processors: Dict[int, Union[LogitsProcessor, LogitsWarper]], |
| 462 | + processors: Dict[int, LogitsProcessor], |
464 | 463 | ): |
465 | 464 | self.processors = processors |
466 | 465 |
|
|
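The wrapper pattern itself is simple; a hedged sketch (class name hypothetical) of routing each sample's row of logits through its own processor, matching the `Dict[int, LogitsProcessor]` signature after this change:

```python
from typing import Dict

import torch
from transformers import LogitsProcessor

class PerSampleProcessorWrapper(LogitsProcessor):
    """Illustrative: run a different processor on each sample's logits."""

    def __init__(self, processors: Dict[int, LogitsProcessor]):
        # Maps batch index -> processor for that sample.
        self.processors = processors

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        # Each processor sees only its own sample's slice, updated in place.
        for i, processor in self.processors.items():
            scores[i : i + 1] = processor(input_ids[i : i + 1], scores[i : i + 1])
        return scores
```

Keeping the slices as `[i : i + 1]` rather than `[i]` preserves the batch dimension, so single-sample processors written for `[batch, vocab]` inputs work unchanged.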