Showing 5 changed files with 422 additions and 0 deletions.
@@ -0,0 +1,186 @@
from typing import Iterable, Dict, List, Union, Any, Callable
from functools import partial
from tqdm import tqdm
from torch.utils.data import Dataset
from torch.distributed import get_rank
import torch
import torch.nn.functional as F

def zero_pad_sequences(sequences: List[torch.Tensor], side: str = "left", value: int = 0) -> torch.Tensor:
    assert side in ("left", "right")
    max_len = max(seq.size(-1) for seq in sequences)
    padded_sequences = []
    for seq in sequences:
        pad_len = max_len - seq.size(-1)
        padding = (pad_len, 0) if side == "left" else (0, pad_len)
        padded_sequences.append(F.pad(seq, padding, value=value))
    return torch.stack(padded_sequences, dim=0)
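# A minimal illustration of the padding behaviour (hypothetical values, not part of this commit):
#   zero_pad_sequences([torch.tensor([1, 2]), torch.tensor([3, 4, 5])], side="left")
#   -> tensor([[0, 1, 2],
#              [3, 4, 5]])
# With side="right" the shorter sequence is padded at the end instead: [[1, 2, 0], [3, 4, 5]].
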

class OfflineRLDataset(Dataset):
    """
    Overview:
        PyTorch Dataset for OfflineRL LLM training like KTO and DPO.
    """

    def __init__(
        self,
        dataset: Iterable[Dict],
        tokenizer,
        max_length: int,
        input_key: str = "input",
        output_key: str = "output",
        label_key: str = "label",
        apply_chat_template: bool = False,
        tokenizer_chat_template: str = None,
        input_template: str = None,
        num_processors: int = 8,
        parallel_load: bool = True
    ) -> None:
        """
        Overview:
            Initialize the OfflineRLDataset.
        Arguments:
            - dataset (Iterable[Dict]): The dataset to preprocess, e.g. a HuggingFace dataset.
            - tokenizer (PreTrainedTokenizer): The tokenizer used to preprocess the data.
            - max_length (int): The maximum token length of prompt + response.
            - input_key (str): The key of the input (prompt) field.
            - output_key (str): The key of the output (response) field.
            - label_key (str): The key of the label field.
            - apply_chat_template (bool): Whether to apply the tokenizer's chat template.
            - tokenizer_chat_template (str): An optional chat template that overrides the tokenizer's default.
            - input_template (str): An optional template used to format the prompt.
            - num_processors (int): The number of processes used when `parallel_load` is True.
            - parallel_load (bool): Whether to preprocess the dataset in parallel with `datasets.map`.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length

        if apply_chat_template:
            apply_chat_template = self.tokenizer.apply_chat_template
            if tokenizer_chat_template:
                self.tokenizer.chat_template = tokenizer_chat_template

        # preprocess the dataset in parallel with `datasets.map`
        if parallel_load:
            preprocess_data_fn = partial(
                self._preprocess_data,
                input_template=input_template,
                input_key=input_key,
                output_key=output_key,
                label_key=label_key,
                apply_chat_template=apply_chat_template
            )
            processed_dataset = dataset.map(
                preprocess_data_fn, remove_columns=dataset.column_names, num_proc=num_processors
            )
            # the preprocess function may set "prompt" to None, so filter those samples out
            processed_dataset = processed_dataset.filter(lambda x: x["prompt"] is not None)

            self.prompts = processed_dataset["prompt"]
            self.responses = processed_dataset["response"]
            self.labels = processed_dataset["label"]
            self.prompt_ids_lens = processed_dataset["prompt_ids_len"]
        else:
            self.prompts = []
            self.responses = []
            self.labels = []
            self.prompt_ids_lens = []
            for data in tqdm(dataset, desc="Preprocessing data", disable=get_rank() != 0):
                processed_data = self._preprocess_data(
                    data,
                    input_template=input_template,
                    input_key=input_key,
                    output_key=output_key,
                    label_key=label_key,
                    apply_chat_template=apply_chat_template
                )
                if processed_data["prompt"] is not None:
                    self.prompts.append(processed_data["prompt"])
                    self.responses.append(processed_data["response"])
                    self.labels.append(processed_data["label"])
                    self.prompt_ids_lens.append(processed_data["prompt_ids_len"])

    def _preprocess_data(
        self,
        data: Dict[str, Any],
        input_template: str = None,
        input_key: str = "input",
        output_key: str = "output",
        label_key: str = "label",
        apply_chat_template: Union[bool, Callable] = False,
    ) -> Dict[str, Any]:
        label = data[label_key]

        if apply_chat_template:
            if output_key:
                prompt = apply_chat_template(data[input_key], tokenize=False, add_generation_prompt=True)
                response = apply_chat_template(data[input_key] + data[output_key], tokenize=False)[len(prompt):]
            else:
                prompt = apply_chat_template(data[input_key][:-1], tokenize=False, add_generation_prompt=True)
                response = apply_chat_template(data[input_key], tokenize=False)[len(prompt):]
        else:
            prompt = data[input_key]
            response = data[output_key]
            if input_template:
                prompt = input_template.format(prompt)

        prompt_token = self.tokenizer(
            prompt,
            max_length=self.max_length,
            # pad to the batch max length in `collate_fn` rather than to the global max length
            padding=False,
            truncation=True,
            return_tensors="pt",
            # special tokens for the prompt are added in `collate_fn`
            add_special_tokens=False,
        )
        prompt_ids_len = prompt_token["attention_mask"].int().sum().item()

        # mark over-long samples: at least 2 tokens must be left for the answer;
        # the caller filters out entries whose prompt is None
        if prompt_ids_len >= self.max_length - 2:
            prompt = None

        return {"prompt": prompt, "response": response, "label": label, "prompt_ids_len": prompt_ids_len}

    def __len__(self) -> int:
        """
        Overview:
            Get the length of the dataset.
        Returns:
            - length (int): The length of the dataset.
        """
        return len(self.prompts)

    def __getitem__(self, idx: int) -> Dict[str, Union[torch.Tensor, int]]:
        """
        Overview:
            Get the item at the given index.
        Returns:
            - item (Dict[str, Union[torch.Tensor, int]]): The item at the given index.
        """
        return {
            "prompt": self.prompts[idx],
            "response": self.responses[idx],
            "label": self.labels[idx],
            "prompt_ids_len": self.prompt_ids_lens[idx]
        }

    def collate_fn(self, item_list: List[Dict[str, Union[torch.Tensor, int]]]):

        def tokenize(prompt: str, response: str):
            text = (prompt + response).rstrip("\n")
            if not text.endswith(self.tokenizer.eos_token):
                text += " " + self.tokenizer.eos_token
            inputs = self.tokenizer(
                text,
                max_length=self.max_length,
                padding=False,
                truncation=True,
                return_tensors="pt",
                add_special_tokens=False,
            )

            # make sure the sequence still ends with EOS even if it was truncated above
            inputs["input_ids"][0][-1] = self.tokenizer.eos_token_id
            inputs["attention_mask"][0][-1] = True
            return inputs["input_ids"], inputs["attention_mask"]

        tot_ids, tot_masks, tot_labels, prompt_ids_lens = [], [], [], []
        for item in item_list:
            input_ids, attention_mask = tokenize(item["prompt"], item["response"])
            tot_ids.append(input_ids)
            tot_masks.append(attention_mask)
            tot_labels.append(item["label"])
            prompt_ids_lens.append(item["prompt_ids_len"])

        # add unmatched y' | x pairs (used to estimate the KL divergence between policy and reference)
        for idx in range(len(item_list)):
            next_idx = (idx + 1) % len(item_list)
            input_ids, attention_mask = tokenize(item_list[idx]["prompt"], item_list[next_idx]["response"])
            tot_ids.append(input_ids)
            tot_masks.append(attention_mask)
            tot_labels.append(-1)
            prompt_ids_lens.append(item_list[idx]["prompt_ids_len"])

        input_ids = zero_pad_sequences(tot_ids, side="right", value=self.tokenizer.pad_token_id)
        attention_mask = zero_pad_sequences(tot_masks, side="right")
        return input_ids, attention_mask, torch.LongTensor(tot_labels), prompt_ids_lens
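A minimal usage sketch (not part of this commit) showing how the dataset could be wired into a DataLoader; the data file, tokenizer and batch size below are placeholder assumptions:

from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

# hypothetical tokenizer and JSON file with "input"/"output"/"label" columns
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B")
hf_dataset = load_dataset("json", data_files="kto_pairs.json", split="train")
offline_dataset = OfflineRLDataset(
    dataset=hf_dataset,
    tokenizer=tokenizer,
    max_length=1024,
)
# collate_fn doubles each batch: every prompt is also paired with the next item's
# response (unmatched y' | x, labelled -1) for the KL estimate used by KTO-style losses
loader = DataLoader(offline_dataset, batch_size=8, collate_fn=offline_dataset.collate_fn)
input_ids, attention_mask, labels, prompt_ids_lens = next(iter(loader))
print(input_ids.size(0))  # 2 * batch_size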
@@ -0,0 +1,96 @@
from typing import Any, Dict, Union, Callable, Iterable
from tqdm import tqdm
from torch.utils.data import Dataset
from torch.distributed import get_rank

class OnlineRLDataset(Dataset):
    """
    Overview:
        PyTorch Dataset for OnlineRL LLM training like PPO.
    """

    def __init__(
        self,
        dataset: Iterable[Dict],
        tokenizer,
        input_key: str = "input",
        apply_chat_template: bool = False,
        input_template: str = None,
    ) -> None:
        """
        Overview:
            Initialize the OnlineRLDataset.
        Arguments:
            - dataset (torch.utils.data.Dataset): The dataset to preprocess.
            - tokenizer (PreTrainedTokenizer): The tokenizer used to preprocess the data.
            - input_key (str): The key of the input data.
            - apply_chat_template (bool): Whether to apply the chat template.
            - input_template (str): The template used to format the data.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.input_template = input_template

        if apply_chat_template:
            apply_chat_template = self.tokenizer.apply_chat_template

        self.prompts = []
        try:
            rank = get_rank()
        except (RuntimeError, ValueError):  # torch.distributed is not initialized yet, e.g. in unit tests
            rank = 0
        for data in tqdm(dataset, desc="Preprocessing data", disable=rank != 0):
            prompt = self._preprocess_data(data, input_template, input_key, apply_chat_template)
            self.prompts.append(prompt)

    def __len__(self) -> int:
        """
        Overview:
            Get the length of the dataset.
        Returns:
            - length (int): The length of the dataset.
        """
        return len(self.prompts)

    def __getitem__(self, idx: int) -> str:
        """
        Overview:
            Get the item at the given index.
        Arguments:
            - idx (int): The index of the item to get.
        Returns:
            - item (str): The item at the given index.
        """
        return self.prompts[idx]

    def _preprocess_data(
        self,
        data: Dict[str, Any],
        input_template: str = None,
        input_key: str = "input",
        apply_chat_template: Union[bool, Callable] = False,
    ) -> str:
        """
        Overview:
            Preprocess the data to get the formatted prompt.
        Arguments:
            - data (Dict[str, Any]): The data to preprocess.
            - input_template (str): The template to format the data.
            - input_key (str): The key of the input data.
            - apply_chat_template (Union[bool, Callable]): The function that applies the chat template, \
                usually `tokenizer.apply_chat_template`.
        Returns:
            - prompt (str): The formatted prompt.
        """
        if apply_chat_template:
            chat = data[input_key]
            if isinstance(chat, str):
                chat = [{"role": "user", "content": chat}]
            assert isinstance(chat, list) and all(isinstance(t, dict) for t in chat), "chat must be a list of dict"
            prompt = apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        else:
            prompt = data[input_key]
            if input_template:
                prompt = input_template.format(prompt)
        return prompt
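A minimal usage sketch (not part of this commit); the prompt file and tokenizer are placeholder assumptions:

from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B")  # placeholder tokenizer
hf_dataset = load_dataset("json", data_files="prompts.json", split="train")  # hypothetical file with an "input" column
online_dataset = OnlineRLDataset(
    dataset=hf_dataset,
    tokenizer=tokenizer,
    input_key="input",
    apply_chat_template=True,  # plain strings are wrapped as a single user turn before templating
)
print(online_dataset[0])  # a formatted prompt string, ready for the rollout/generation step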
@@ -0,0 +1,99 @@
import pytest
from datasets import load_dataset, concatenate_datasets
from rl.data.offlinerl_dataset import OfflineRLDataset
from transformers import AutoTokenizer


@pytest.fixture
def dataset():
    # Load a sample dataset
    hf_dataset = load_dataset("MMInstruction/VL-RewardBench", split="test")
    # split pair data into two separate datasets
    hf_dataset_1 = hf_dataset.map(
        lambda x: {
            "prompt": x["query"],
            "response": x["response"][0],
            "human_ranking": x["human_ranking"][0]
        }
    )
    hf_dataset_2 = hf_dataset.map(
        lambda x: {
            "prompt": x["query"],
            "response": x["response"][1],
            "human_ranking": x["human_ranking"][0]
        }
    )
    # combine the two datasets
    hf_dataset = concatenate_datasets([hf_dataset_1, hf_dataset_2])
    # shuffle the dataset
    hf_dataset = hf_dataset.shuffle(seed=42)
    return hf_dataset


@pytest.fixture
def tokenizer():
    # Load a tokenizer
    return AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B")

@pytest.mark.unittest
def test_offline_rl_dataset_initialization(dataset, tokenizer):
    # Test the initialization of the OfflineRLDataset
    offline_dataset = OfflineRLDataset(
        dataset=dataset,
        tokenizer=tokenizer,
        max_length=1024,
        input_key="query",
        output_key="response",
        label_key="human_ranking"
    )
    assert len(offline_dataset) == len(dataset)
    offline_dataset = OfflineRLDataset(
        dataset=dataset,
        tokenizer=tokenizer,
        max_length=256,
        input_key="query",
        output_key="response",
        label_key="human_ranking"
    )
    # a lower max_length will filter out some samples
    assert len(offline_dataset) < len(dataset)


@pytest.mark.unittest
def test_offline_rl_dataset_item_retrieval(dataset, tokenizer):
    # Test retrieving an item from the OfflineRLDataset
    offline_dataset = OfflineRLDataset(
        dataset=dataset,
        tokenizer=tokenizer,
        max_length=256,
        input_key="query",
        output_key="response",
        label_key="human_ranking"
    )
    item = offline_dataset[0]
    assert "prompt" in item
    assert "response" in item
    assert "label" in item
    assert "prompt_ids_len" in item
    print(item)


@pytest.mark.unittest
def test_offline_rl_dataset_collate_fn(dataset, tokenizer):
    # Test the collate function of the OfflineRLDataset
    offline_dataset = OfflineRLDataset(
        dataset=dataset,
        tokenizer=tokenizer,
        max_length=256,
        input_key="query",
        output_key="response",
        label_key="human_ranking"
    )
    B = 10
    item_list = [offline_dataset[i] for i in range(B)]
    input_ids, attention_mask, labels, prompt_ids_lens = offline_dataset.collate_fn(item_list)
    assert input_ids.size(0) == len(item_list) * 2  # doubled because of the unmatched y' | x pairs
    assert attention_mask.size(0) == len(item_list) * 2
    assert labels.size(0) == len(item_list) * 2
    assert len(prompt_ids_lens) == len(item_list) * 2
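These tests download the VL-RewardBench dataset and the Qwen2.5-Coder-7B tokenizer from the Hugging Face Hub, so they need network access. Assuming the `unittest` marker is registered in the project's pytest configuration, they could be run with something like:

pytest -m unittest path/to/test_offlinerl_dataset.py -s  # replace with the actual test file path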