-
Notifications
You must be signed in to change notification settings - Fork 32
Modified BertClassifier class, new BertEmbeddingGenerator class #27
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
DevBerge
wants to merge
2
commits into
mim-solutions:main
Choose a base branch
from
DevBerge:embedder
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
import torch | ||
|
||
from belt_nlp.bert_with_pooling import BertClassifier | ||
from typing import Optional, List | ||
from torch import Tensor | ||
from torch.nn import Module | ||
from transformers import PreTrainedTokenizerBase, BatchEncoding | ||
from belt_nlp.splitting import transform_list_of_texts | ||
|
||
|
||
class BertEmbeddingGenerator(BertClassifier): | ||
def __init__( | ||
self, | ||
chunk_size: int, | ||
stride: int, | ||
minimal_chunk_length: int, | ||
pooling_strategy: str = "mean", | ||
maximal_text_length: Optional[int] = None, | ||
tokenizer: Optional[PreTrainedTokenizerBase] = None, | ||
neural_network: Optional[Module] = None, | ||
pretrained_model_name_or_path: Optional[str] = "bert-base-uncased", | ||
trust_remote_code: Optional[bool] = False, | ||
device: str = "cuda:0", | ||
many_gpus: bool = False, | ||
): | ||
|
||
super().__init__( | ||
tokenizer=tokenizer, | ||
neural_network=neural_network, | ||
pretrained_model_name_or_path=pretrained_model_name_or_path, | ||
trust_remote_code=trust_remote_code, | ||
device=device, | ||
many_gpus=many_gpus | ||
) | ||
|
||
self.chunk_size = chunk_size | ||
self.stride = stride | ||
self.minimal_chunk_length = minimal_chunk_length | ||
self.maximal_text_length = maximal_text_length | ||
self.pooling_strategy = pooling_strategy | ||
if pooling_strategy not in ["mean", "max"]: | ||
raise ValueError("Unknown pooling strategy!") | ||
|
||
self.collate_fn = BertEmbeddingGenerator.collate_fn_pooled_tokens | ||
|
||
def _tokenize(self, texts: list[str]) -> BatchEncoding: | ||
""" | ||
Transforms list of N texts to the BatchEncoding, that is the dictionary with the following keys: | ||
- input_ids - List of N tensors of the size K(i) x 512 of token ids. | ||
K(i) is the number of chunks of the text i. | ||
Each element of the list is stacked Tensor for encoding of each chunk. | ||
Values of the tensor are integers. | ||
- attention_mask - List of N tensors of the size K(i) x 512 of attention masks. | ||
K(i) is the number of chunks of the text i. | ||
Each element of the list is stacked Tensor for encoding of each chunk. | ||
Values of the tensor are booleans. | ||
|
||
These lists of tensors cannot be stacked into one tensor, | ||
because each text can be divided into different number of chunks. | ||
""" | ||
tokens = transform_list_of_texts( | ||
texts, self.tokenizer, self.chunk_size, self.stride, self.minimal_chunk_length, self.maximal_text_length | ||
) | ||
return tokens | ||
|
||
def get_embeddings(self, documents: List[str]) -> List[Tensor]: | ||
|
||
all_embeddings = [] | ||
for document in documents: | ||
tokens = self._tokenize([document]) | ||
|
||
input_ids, attention_masks = tokens["input_ids"], tokens["attention_mask"] | ||
|
||
# Process each document's chunks and pool their embeddings | ||
document_embedding = self.process_and_pool_chunks((input_ids, attention_masks)) | ||
all_embeddings.append(document_embedding) | ||
|
||
return torch.stack(all_embeddings).tolist() | ||
|
||
def process_and_pool_chunks(self, batch: tuple[Tensor]): | ||
input_ids = batch[0][0].to(self.device) | ||
attention_mask = batch[1][0].to(self.device) | ||
|
||
model_output = self.neural_network(input_ids, attention_mask=attention_mask, return_embeddings=True) | ||
sequence_output = model_output[:, 0, :] # Taking CLS token as my pretrained model performs better with it | ||
|
||
if self.pooling_strategy == "mean": | ||
pooled_output = torch.mean(sequence_output, dim=0).detach().cpu() | ||
elif self.pooling_strategy == "max": | ||
pooled_output = torch.max(sequence_output, dim=0).values | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am curious about this part:
|
||
else: | ||
raise ValueError("Unknown pooling strategy!") | ||
|
||
return pooled_output | ||
|
||
def _evaluate_single_batch(self, batch: tuple[Tensor]) -> Tensor: | ||
pass | ||
|
||
@staticmethod | ||
def collate_fn_pooled_tokens(data): | ||
input_ids = [data[i][0] for i in range(len(data))] | ||
attention_mask = [data[i][1] for i in range(len(data))] | ||
if len(data[0]) == 2: | ||
collated = [input_ids, attention_mask] | ||
else: | ||
labels = Tensor([data[i][2] for i in range(len(data))]) | ||
collated = [input_ids, attention_mask, labels] | ||
return collated | ||
|
||
|
||
|
||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This inheritance from BertClassifier is confusing - this is not the classifier. I suggest refactoring it using composition over inheritance principle.