Skip to content
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ install_requires =
accelerate>=0.22.0
huggingface-hub>=0.11.1,<1.0.0
tokenizers>=0.13.3
transformers>=4.32.0,<4.35.0 # if you change this, please also change version assert in petals/__init__.py
transformers>=4.35.0 # 4.35.0 is the minimum that contains modeling_attn_mask_utils.py
speedtest-cli==2.1.3
pydantic>=1.10,<2.0 # 2.0 is incompatible with hivemind yet
hivemind==1.1.10.post2
Expand Down
8 changes: 1 addition & 7 deletions src/petals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,7 @@
from petals.utils import *
from petals.utils.logging import initialize_logs as _initialize_logs

__version__ = "2.3.0.dev1"


if not os.getenv("PETALS_IGNORE_DEPENDENCY_VERSION"):
assert (
version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("4.35.0")
), "Please install a proper transformers version: pip install transformers>=4.32.0,<4.35.0"
__version__ = "2.3.0.dev2"


def _override_bfloat16_mode_default():
Expand Down
3 changes: 3 additions & 0 deletions src/petals/client/remote_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class _SkipTokensMixin:
def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> dict:
input_ids = input_ids[:, _skipped_tokens.get() :]
_skipped_tokens.set(0)
if "past_key_values" in kwargs:
if kwargs["past_key_values"][0][0].shape == torch.Size([0]):
kwargs["past_key_values"] = None
return super().prepare_inputs_for_generation(input_ids, **kwargs)


Expand Down
10 changes: 9 additions & 1 deletion src/petals/models/bloom/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Optional, Tuple

import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel, build_alibi_tensor


Expand All @@ -26,7 +27,14 @@ def forward(
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
if alibi is None:
alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
attention_mask = BloomModel._prepare_attn_mask(None, attention_mask, (batch_size, seq_length), past_length)

attention_mask = _prepare_4d_causal_attention_mask(
attention_mask=attention_mask,
input_shape=(batch_size, seq_length),
inputs_embeds=hidden_states,
past_key_values_length=past_length,
)
attention_mask = attention_mask.bool()
return super().forward(
hidden_states, *args, attention_mask=attention_mask, alibi=alibi, layer_past=layer_past, **kwargs
)
9 changes: 8 additions & 1 deletion src/petals/models/falcon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.models.falcon.modeling_falcon import (
FalconAttention,
FalconConfig,
Expand Down Expand Up @@ -418,7 +419,13 @@ def forward(
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
if alibi is None and self.config.alibi:
alibi = build_alibi_tensor(attention_mask, num_heads=self.num_heads, dtype=hidden_states.dtype)
attention_mask = FalconModel._prepare_attn_mask(attention_mask, (batch_size, seq_length), past_length)

attention_mask = _prepare_4d_causal_attention_mask(
attention_mask=attention_mask,
input_shape=(batch_size, seq_length),
inputs_embeds=hidden_states,
past_key_values_length=past_length,
)

outputs = super().forward(
hidden_states,
Expand Down
12 changes: 8 additions & 4 deletions src/petals/models/llama/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaConfig,
Expand Down Expand Up @@ -84,8 +85,8 @@ def forward(
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
cos = cos[:, :, kv_seq_len - q_len :]
sin = sin[:, :, kv_seq_len - q_len :]
cos = cos[kv_seq_len - q_len :]
sin = sin[kv_seq_len - q_len :]

if q_len == 1 and torch.is_inference_mode_enabled() and hidden_states.device.type == "cuda":
query_states, key_states = self._optimized_apply_rotary(query_states, key_states, cos, sin)
Expand Down Expand Up @@ -244,8 +245,11 @@ def forward(
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
)
attention_mask = LlamaModel._prepare_decoder_attention_mask(
None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask=attention_mask,
input_shape=(batch_size, seq_length),
inputs_embeds=hidden_states,
past_key_values_length=past_key_values_length,
)

outputs = super().forward(
Expand Down
5 changes: 3 additions & 2 deletions tests/test_optimized_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel

Expand Down Expand Up @@ -131,8 +132,8 @@ def forward(
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
)
attention_mask = LlamaModel._prepare_decoder_attention_mask(
None, attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
)

outputs = super().forward(
Expand Down