This repository was archived by the owner on Sep 23, 2025. It is now read-only.

Commit 0b44ac4

[Finetune] Fix evaluation (#252)
* fix evaluation
* update doc
* fix comments

Signed-off-by: minmingzhu <minming.zhu@intel.com>
1 parent fc45da8 · commit 0b44ac4

5 files changed: +271 additions, −113 deletions

docs/finetune_parameters.md

Lines changed: 19 additions & 10 deletions

@@ -19,16 +19,25 @@ The following are the parameters supported in the finetuning workflow.

 ## Dataset Parameters
-|Configuration Name| Default|Meaning|
-|-|-|-|
-|train_file|examples/data/sample_finetune_data.jsonl|A json file containing the training data.|
-|validation_file|None|A json file containing the validation data.|
-|validation_split_percentage|5|The percentage of the train set used as validation set in case there's no validation split|
-|preprocessing_num_workers|None|The number of processes to use for the preprocessing.|
-|max_length|512|Padding sequential data to max length of a batch|
-|group|True|Whether to concatenate the sentence for more efficient training|
-|block_size|512|The block size of concatenated sentence|
-|shuffle|False|Whether shuffle the data at every epoch|
+| Configuration Name | Default | Meaning |
+|-|-|-|
+| train_file | examples/data/sample_finetune_data.jsonl | A JSON file containing the training data. |
+| validation_file | None | A JSON file containing the validation data. |
+| validation_split_percentage | 5 | The percentage of the train set used as a validation set when there is no validation split. |
+| preprocessing_num_workers | None | The number of processes to use for preprocessing. |
+| max_length | 512 | Pad sequential data to the max length of a batch. |
+| group | True | Whether to concatenate sentences for more efficient training. |
+| block_size | 512 | The block size of concatenated sentences. |
+| shuffle | False | Whether to shuffle the data at every epoch. |
+| max_source_length | 384 | The maximum total input sequence length after tokenization. Longer sequences are truncated; shorter ones are padded. |
+| padding_side | right | The side on which padding is applied. One of ['right', 'left']. |
+| truncation_side | right | The side on which truncation is applied. One of ['right', 'left']. |
+| max_seq_length | max_length | The maximum total input sequence length after tokenization. |
+| truncation | True | Truncation strategy. One of ['only_first', 'only_second', 'longest_first'/True, 'do_not_truncate'/False]. |
+| padding | True | Padding strategy. One of ['longest'/True, 'do_not_pad'/False, 'max_length']. |
+| mask_input | True | Whether to mask the input part in the labels. |
+| mask_response | True | Whether to mask the response part in the labels. |
+| data_preprocess_type | neural_chat | The type of input encoding. |


 ## Training Parameters
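
To see the new options in context, here is a minimal sketch of ours (not part of the commit) of the `Dataset` section as the plain dict that `DataProcessor` reads via `config["Dataset"].get(...)`. In the repo these values come from the YAML finetune config; every value shown is just the documented default.

```python
# Illustrative defaults only; DataProcessor falls back to these when a key
# is omitted. mask_input takes precedence over mask_response in the
# neural_chat tokenization path (the if/elif in data_process.py below).
config = {
    "Dataset": {
        "train_file": "examples/data/sample_finetune_data.jsonl",
        "max_length": 512,          # also reused as max_seq_length
        "max_source_length": 384,   # token budget reserved for the prompt
        "padding_side": "right",
        "truncation_side": "right",
        "truncation": True,
        "padding": True,
        "mask_input": True,         # prompt tokens get IGNORE_INDEX in labels
        "mask_response": True,      # only takes effect when mask_input is False
        "data_preprocess_type": "neural_chat",
    }
}
```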
llm_on_ray/finetune/data_process.py (new file)

Lines changed: 220 additions & 0 deletions

@@ -0,0 +1,220 @@
#
# Copyright 2023 The LLM-on-Ray Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import copy
import re
from itertools import chain

import torch

IGNORE_INDEX = -100


class DataProcessor:
    # We use the prompts from fine-tuning the Alpaca model. Reference:
    # https://github.com/tatsu-lab/stanford_alpaca/blob/main/README.md#data-release
    def __init__(self, config, tokenizer):
        self.tokenizer = tokenizer
        self.end = tokenizer.eos_token
        self.intro = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
        self.instruction = "### Instruction:\n"
        self.input = "### Input:\n"
        self.response = "### Response:\n"
        self.padding_side = config["Dataset"].get("padding_side", "right")
        self.truncation_side = config["Dataset"].get("truncation_side", "right")
        self.max_length = self.max_seq_length = config["Dataset"].get("max_length", 512)
        self.max_source_length = config["Dataset"].get("max_source_length", 384)
        self.truncation = config["Dataset"].get("truncation", True)
        self.padding = config["Dataset"].get("padding", True)
        self.mask_input = config["Dataset"].get("mask_input", True)
        self.mask_response = config["Dataset"].get("mask_response", True)

    def make_prompt(self, examples):
        prompts = {}
        prompts["prompt_sources"] = []
        prompts["prompt_targets"] = []
        for rec in examples:
            instruction = rec["instruction"]
            response = rec["response"]
            context = rec.get("context")
            if not instruction:
                raise ValueError(f"Expected an instruction in: {rec}")
            if not response:
                raise ValueError(f"Expected a response in: {rec}")
            if context:
                prompt = (
                    self.intro
                    + self.end
                    + "\n"
                    + self.instruction
                    + instruction
                    + self.input
                    + context
                    + self.end
                    + "\n"
                    + self.response
                )
                prompts["prompt_sources"].append(prompt)
            else:
                prompt = (
                    self.intro
                    + self.end
                    + "\n"
                    + self.instruction
                    + instruction
                    + self.end
                    + "\n"
                    + self.response
                )
                prompts["prompt_sources"].append(prompt)
            prompt_response = response + self.end
            prompts["prompt_targets"].append(prompt_response)
        return prompts

    def __truncate_sequences(self, sequences, max_length):
        """
        Copied from https://github.com/intel/intel-extension-for-transformers/blob/ae54f698b73a66e5729427cb19f69c33e1a5c34d/intel_extension_for_transformers/transformers/llm/finetuning/data_utils.py#L40
        """
        words_to_cut = sum(list(map(len, sequences))) - max_length
        if words_to_cut <= 0:
            return sequences

        while words_to_cut > 0 and len(sequences) > 0:
            words_to_cut -= len(sequences[0])
            sequences = sequences[1:]
        return sequences

    def tokenize_by_neural_chat(self, examples):
        """
        Copied from https://github.com/intel/intel-extension-for-transformers/blob/ae54f698b73a66e5729427cb19f69c33e1a5c34d/intel_extension_for_transformers/transformers/llm/finetuning/data_utils.py#L225
        The only differences are:
        - using our own prompt style
        - add left or right padding and truncation
        - add mask_input and mask_response
        """
        keys = list(examples.data.keys())
        if len(keys) != 2:
            raise ValueError("Unsupported dataset format")
        assistant_tokens = self.tokenizer.tokenize(self.response)
        header = self.intro + self.end + "\n"

        examples["input_ids"] = []
        examples["labels"] = []
        examples["attention_mask"] = []
        for instruction, response in zip(examples[keys[0]], examples[keys[1]]):
            convs = re.findall(
                r"{0}.*?{2}|{1}.*?{2}".format(self.instruction, self.response, self.end),
                instruction,
                re.DOTALL,
            )
            convs_tokens = [
                self.tokenizer.tokenize(conv) + self.tokenizer.tokenize("\n") for conv in convs
            ]
            header_tokens = self.tokenizer.tokenize(header) + self.tokenizer.tokenize("\n")
            max_input = self.max_source_length - len(header_tokens) - len(assistant_tokens)
            truncated_convs = self.__truncate_sequences(convs_tokens, max_input)
            if len(truncated_convs) == 0:
                truncated_convs = [convs_tokens[-1][: max_input - 3] + convs_tokens[-1][-3:]]

            prompt_tokens = [header_tokens] + truncated_convs + [assistant_tokens]
            prompt_ids = [
                self.tokenizer.convert_tokens_to_ids(prompt_token) for prompt_token in prompt_tokens
            ]
            prompt_ids = list(chain(*prompt_ids))

            resp_ids = self.tokenizer.convert_tokens_to_ids(
                self.tokenizer.tokenize(response.strip())
            )
            # keep last and eos_id
            max_resp = self.max_seq_length - len(prompt_ids) - 1

            # truncating response
            if len(resp_ids) > max_resp:
                if self.truncation_side == "right":
                    resp_ids = resp_ids[: max_resp - 1] + resp_ids[-1:]
                else:
                    resp_ids = resp_ids[-max_resp:]

            # masking
            input_ids = prompt_ids + resp_ids + [self.tokenizer.eos_token_id]
            if self.mask_input:
                labels = [IGNORE_INDEX] * len(prompt_ids) + resp_ids + [self.tokenizer.eos_token_id]
            elif self.mask_response:
                labels = prompt_ids + [IGNORE_INDEX] * len(resp_ids) + [self.tokenizer.eos_token_id]
            else:
                labels = input_ids

            # padding
            input_len = len(input_ids)
            pad_len = self.max_seq_length - input_len
            if self.padding_side == "right":
                input_ids = input_ids + [self.tokenizer.eos_token_id] * pad_len
                labels = labels + [IGNORE_INDEX] * pad_len
                attention_mask = [1] * input_len + [0] * pad_len
            else:
                input_ids = [self.tokenizer.eos_token_id] * pad_len + input_ids
                labels = [IGNORE_INDEX] * pad_len + labels
                attention_mask = [0] * pad_len + [1] * input_len

            assert len(input_ids) == self.max_seq_length
            assert len(prompt_ids) <= self.max_source_length
            assert len(labels) == len(input_ids) == len(attention_mask)

            examples["input_ids"].append(torch.tensor(input_ids))
            examples["labels"].append(labels)
            examples["attention_mask"].append(attention_mask)

        return examples

    def tokenize(self, examples):
        keys = list(examples.data.keys())
        if len(keys) != 2:
            raise ValueError("Unsupported dataset format")

        examples["input_ids"] = []
        examples["labels"] = []
        examples["attention_mask"] = []
        for s, t in zip(examples[keys[0]], examples[keys[1]]):
            results = self.tokenizer(
                s + t,
                padding=self.padding,
                truncation=self.truncation,
                return_tensors=None,
                max_length=self.max_length,
            )

            input_ids = results["input_ids"]
            input_len = len(input_ids)
            labels = copy.deepcopy(input_ids)
            if self.mask_input or self.mask_response:
                sources_tokenized = self.tokenizer(
                    s,
                    padding=False,
                    truncation=True,
                    return_tensors=None,
                    max_length=self.max_length,
                )
                input_id_len = len(sources_tokenized["input_ids"])
                # mask input
                if self.mask_input:
                    labels[:input_id_len] = [IGNORE_INDEX] * input_id_len
                # mask response
                if self.mask_response:
                    labels[input_id_len:input_len] = [IGNORE_INDEX] * (input_len - input_id_len)

            examples["input_ids"].append(results["input_ids"])
            examples["labels"].append(labels)
            examples["attention_mask"].append(results["attention_mask"])
        return examples
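
For orientation, a minimal usage sketch of this class, mirroring how `tokenize_dataset()` in finetune.py (diff below) drives it. The tokenizer name and the sample record are placeholders of ours, not from the commit; `hf-internal-testing/llama-tokenizer` is chosen only because its `</s>` eos token contains no regex-special characters for the conversation-splitting pattern above.

```python
# Hedged sketch, not part of the commit: drive DataProcessor the same way
# tokenize_dataset() does. Tokenizer name and record are placeholders.
import datasets
import transformers

from llm_on_ray.finetune.data_process import DataProcessor

tokenizer = transformers.AutoTokenizer.from_pretrained(
    "hf-internal-testing/llama-tokenizer", padding_side="right", truncation_side="right"
)
tokenizer.pad_token = tokenizer.eos_token  # finetune.py sets this before tokenizing

config = {"Dataset": {"max_length": 512, "max_source_length": 384}}
processor = DataProcessor(config, tokenizer)

records = [{"instruction": "Name a color.", "context": None, "response": "Blue."}]
prompts = processor.make_prompt(records)  # {"prompt_sources": [...], "prompt_targets": [...]}
ds = datasets.Dataset.from_dict(prompts)
tokenized = ds.map(
    processor.tokenize_by_neural_chat,    # or processor.tokenize for the plain path
    remove_columns=list(ds.features),
    batched=True,                         # both tokenize methods expect batched input
)
```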

llm_on_ray/finetune/finetune.py

Lines changed: 23 additions & 39 deletions

@@ -18,7 +18,10 @@

 import os
 import argparse
+import re
 import sys
+import copy
+
 from typing import Any, Dict, Union, Optional

 from itertools import chain

@@ -37,9 +40,8 @@
 from pydantic_yaml import parse_yaml_raw_as

 from llm_on_ray import common
-from llm_on_ray.finetune import template
+from llm_on_ray.finetune.data_process import DataProcessor
 from llm_on_ray.finetune.finetune_config import FinetuneConfig
-from importlib import util


 def adapt_transformers_to_device(config: Dict):

@@ -140,7 +142,13 @@ def load_tokenizer(config: Dict):
     else:
         tokenizer_name = config["General"]["base_model"]
     load_config = config["General"].get("config", {})
-    tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name, **load_config)
+    # default padding side is right
+    padding_side = config["Dataset"].get("padding_side", "right")
+    # default truncation side is right
+    truncation_side = config["Dataset"].get("truncation_side", "right")
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
+        tokenizer_name, padding_side=padding_side, truncation_side=truncation_side, **load_config
+    )
     return tokenizer


@@ -195,50 +203,27 @@ def local_load(name, **load_config):


 def tokenize_dataset(config: Dict, tokenizer, dataset):
-    max_length = config["Dataset"].get("max_length", 512)
     group = config["Dataset"].get("group", True)
     block_size = config["Dataset"].get("block_size", 512)
     tokenizer.pad_token = tokenizer.eos_token

-    if isinstance(dataset, datasets.Dataset):
-        column_names = dataset.column_names
-
-    if isinstance(dataset, datasets.DatasetDict):
-        column_names = dataset["train"].column_names
-
-    if column_names and template.TEXT_COLUMN_NAME not in column_names:
-
-        def prompt(rec):
-            instruction = rec["instruction"]
-            response = rec["response"]
-            context = rec.get("context")
-            if not instruction:
-                raise ValueError(f"Expected an instruction in: {rec}")
-            if not response:
-                raise ValueError(f"Expected a response in: {rec}")
-            if context:
-                rec["text"] = template.PROMPT_WITH_INPUT_FORMAT.format(
-                    instruction=instruction, response=response, input=context
-                )
-            else:
-                rec["text"] = template.PROMPT_NO_INPUT_FORMAT.format(
-                    instruction=instruction, response=response
-                )
-            return rec
+    processor = DataProcessor(config, tokenizer)

-        dataset = dataset.map(
-            prompt,
-            load_from_cache_file=False,
-            desc="Prompt",
-        )
-        column_names += [template.TEXT_COLUMN_NAME]
+    for key in dataset:
+        prompts = processor.make_prompt(dataset[key])
+        dataset[key] = datasets.Dataset.from_dict(prompts)

-    def tokenize_function(examples):
-        return tokenizer(examples[template.TEXT_COLUMN_NAME], max_length=max_length)
+    column_names = list(dataset["train"].features)
+    tokenize_fn = (
+        processor.tokenize_by_neural_chat
+        if config["Dataset"].get("data_preprocess_type", "neural_chat") == "neural_chat"
+        else processor.tokenize
+    )

     tokenized_dataset = dataset.map(
-        tokenize_function,
+        tokenize_fn,
         remove_columns=column_names,
+        batched=True,
         load_from_cache_file=False,
         desc="Tokenize dataset",
     )

@@ -258,7 +243,6 @@ def group_texts(examples):
             k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
             for k, t in concatenated_examples.items()
         }
-        result["labels"] = result["input_ids"].copy()
         return result

     tokenized_dataset = tokenized_dataset.map(
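
A note of ours on the deleted `result["labels"] = result["input_ids"].copy()` line: labels are now built at tokenization time with IGNORE_INDEX over the masked span, so they no longer equal input_ids, and regenerating them from input_ids after grouping would silently undo the masking. A tiny illustration with invented token ids:

```python
# Invented ids, illustration only: with mask_input=True the prompt span is
# excluded from the loss, so labels must be carried through group_texts
# rather than recopied from input_ids.
IGNORE_INDEX = -100
prompt_ids = [11, 12, 13]  # hypothetical tokenized prompt
resp_ids = [21, 22]        # hypothetical tokenized response
eos_id = 2
input_ids = prompt_ids + resp_ids + [eos_id]
labels = [IGNORE_INDEX] * len(prompt_ids) + resp_ids + [eos_id]
assert labels != input_ids  # input_ids.copy() would have erased the mask
```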

llm_on_ray/finetune/finetune_config.py

Lines changed: 9 additions & 0 deletions

@@ -71,6 +71,15 @@ class Dataset(BaseModel):
     group: bool = True
     block_size: int = 512
     shuffle: bool = False
+    max_source_length: int = 384
+    padding_side: str = "right"
+    truncation_side: str = "right"
+    max_seq_length: int = 512
+    truncation: bool = True
+    padding: bool = True
+    mask_input: bool = True
+    mask_response: bool = True
+    data_preprocess_type: str = "neural_chat"


 class RayResourceConfig(BaseModel):
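
Since every added field carries a default, YAML configs written before this change should keep validating unchanged. A quick sketch, under our assumptions that `Dataset` is importable as below and that `train_file` is the only field that must be supplied:

```python
# Sketch under the assumptions stated above: omitted fields fall back to
# the defaults introduced in this commit.
from llm_on_ray.finetune.finetune_config import Dataset

ds = Dataset(train_file="examples/data/sample_finetune_data.jsonl")
assert ds.data_preprocess_type == "neural_chat"
assert ds.max_source_length == 384 and ds.padding_side == "right"
assert ds.mask_input and ds.mask_response
```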
