Skip to content

Commit

Permalink
feature(nyz): add rlhf dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
PaParaZz1 committed Jan 24, 2025
1 parent bf258f8 commit 702f244
Show file tree
Hide file tree
Showing 5 changed files with 422 additions and 0 deletions.
186 changes: 186 additions & 0 deletions ding/utils/data/rlhf_offline_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
from typing import Iterable, Dict, List, Union, Any, Callable
from functools import partial
from tqdm import tqdm
from torch.utils.data import Dataset
from torch.distributed import get_rank
import torch
import torch.nn.functional as F


def zero_pad_sequences(sequences: List[torch.Tensor], side: str = "left", value: int = 0) -> torch.Tensor:
assert side in ("left", "right")
max_len = max(seq.size(-1) for seq in sequences)
padded_sequences = []
for seq in sequences:
pad_len = max_len - seq.size(-1)
padding = (pad_len, 0) if side == "left" else (0, pad_len)
padded_sequences.append(F.pad(seq, padding, value=value))
return torch.stack(padded_sequences, dim=0)


class OfflineRLDataset(Dataset):
"""
Overview:
PyTorch Dataset for OfflineRL LLM training like KTO and DPO.
"""

def __init__(
self,
dataset: Iterable[Dict],
tokenizer,
max_length: int,
input_key: str = "input",
output_key: str = "output",
label_key: str = "label",
apply_chat_template: bool = False,
tokenizer_chat_template: str = None,
input_template: str = None,
num_processors: int = 8,
parallel_load: bool = True
) -> None:
super().__init__()
self.tokenizer = tokenizer
self.max_length = max_length

if apply_chat_template:
apply_chat_template = self.tokenizer.apply_chat_template
if tokenizer_chat_template:
self.tokenizer.chat_template = tokenizer_chat_template

# Parallel loading datasets
if parallel_load:
preprocess_data_fn = partial(
self._preprocess_data,
input_template=input_template,
input_key=input_key,
output_key=output_key,
label_key=label_key,
apply_chat_template=apply_chat_template
)
processed_dataset = dataset.map(
preprocess_data_fn, remove_columns=dataset.column_names, num_proc=num_processors
)
# preprocess function may return None, so filter out the None
processed_dataset = processed_dataset.filter(lambda x: x["prompt"] is not None)

self.prompts = processed_dataset["prompt"]
self.responses = processed_dataset["response"]
self.labels = processed_dataset["label"]
self.prompt_ids_lens = processed_dataset["prompt_ids_len"]
else:
self.prompts = []
self.responses = []
self.labels = []
self.prompt_ids_lens = []
for data in tqdm(dataset, desc="Preprocessing data", disable=not get_rank() == 0):
processed_data = self._preprocess_data(data)
if processed_data["prompt"] is not None:
self.prompts.append(processed_data["prompt"])
self.responses.append(processed_data["response"])
self.labels.append(processed_data["label"])
self.prompt_ids_lens.append(processed_data["prompt_ids_len"])

def _preprocess_data(
self,
data: Dict[str, Any],
input_template: str = None,
input_key: str = "input",
output_key: str = "output",
label_key: str = "label",
apply_chat_template: Union[bool, Callable] = False,
) -> str:
label = data[label_key]

if apply_chat_template:
if output_key:
prompt = apply_chat_template(data[input_key], tokenize=False, add_generation_prompt=True)
response = apply_chat_template(data[input_key] + data[output_key], tokenize=False)[len(prompt):]
else:
prompt = apply_chat_template(data[input_key][:-1], tokenize=False, add_generation_prompt=True)
response = apply_chat_template(data[input_key], tokenize=False)[len(prompt):]
else:
prompt = data[input_key]
response = data[output_key]
if input_template:
prompt = input_template.format(prompt)

prompt_token = self.tokenizer(
prompt,
max_length=self.max_length,
# use the batch max length (in `collate_fn`) to pad rather than the global max length
padding=False,
truncation=True,
return_tensors="pt",
# add special tokens for the prompt in `collate_fn`
add_special_tokens=False,
)
prompt_ids_len = prompt_token["attention_mask"].int().sum().item()

# filter the sample whose length is greater than max_length (2 for answer length)
if prompt_ids_len >= self.max_length - 2:
prompt = None

return {"prompt": prompt, "response": response, "label": label, "prompt_ids_len": prompt_ids_len}

def __len__(self) -> int:
"""
Overview:
Get the length of the dataset.
Returns:
- length (int): The length of the dataset.
"""
return len(self.prompts)

def __getitem__(self, idx: int) -> Dict[str, Union[torch.Tensor, int]]:
"""
Overview:
Get the item at the given index.
Returns:
- item (Dict[str, Union[torch.Tensor, int]]): The item at the given index.
"""
return {
"prompt": self.prompts[idx],
"response": self.responses[idx],
"label": self.labels[idx],
"prompt_ids_len": self.prompt_ids_lens[idx]
}

def collate_fn(self, item_list: List[Dict[str, Union[torch.Tensor, int]]]):

def tokenizer(prompt: str, response: str):
text = (prompt + response).rstrip("\n")
if not text.endswith(self.tokenizer.eos_token):
text += " " + self.tokenizer.eos_token
inputs = self.tokenizer(
text,
max_length=self.max_length,
padding=False,
truncation=True,
return_tensors="pt",
add_special_tokens=False,
)

inputs["input_ids"][0][-1] = self.tokenizer.eos_token_id
inputs["attention_mask"][0][-1] = True
return inputs["input_ids"], inputs["attention_mask"]

tot_ids, tot_masks, tot_labels, prompt_ids_lens = [], [], [], []
for item in item_list:
input_ids, attention_mask = tokenizer(item["prompt"], item["response"])
tot_ids.append(input_ids)
tot_masks.append(attention_mask)
tot_labels.append(item["label"])
prompt_ids_lens.append(item["prompt_ids_len"])

# add unmatched y'| x (used to estimate the KL divergence between policy and reference)
for idx in range(len(item_list)):
next_idx = (idx + 1) % len(item_list)
input_ids, attention_mask = tokenizer(item_list[idx]["prompt"], item_list[next_idx]["response"])
tot_ids.append(input_ids)
tot_masks.append(attention_mask)
tot_labels.append(-1)
prompt_ids_lens.append(item_list[idx]["prompt_ids_len"])

input_ids = zero_pad_sequences(tot_ids, side="right", value=self.tokenizer.pad_token_id)
attention_mask = zero_pad_sequences(tot_masks, side="right")
return input_ids, attention_mask, torch.LongTensor(tot_labels), prompt_ids_lens
96 changes: 96 additions & 0 deletions ding/utils/data/rlhf_online_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from typing import Any, Dict, Union, Callable, Iterable
from tqdm import tqdm
from torch.utils.data import Dataset
from torch.distributed import get_rank


class OnlineRLDataset(Dataset):
"""
Overview:
PyTorch Dataset for OnlineRL LLM training like PPO.
"""

def __init__(
self,
dataset: Iterable[Dict],
tokenizer,
input_key: str = "input",
apply_chat_template: bool = False,
input_template: str = None,
) -> None:
"""
Overview:
Initialize the OnlineRLDataset.
Arguments:
- dataset (torch.utils.data.Dataset): The dataset to preprocess.
- tokenizer (): The tokenizer to preprocess the data.
- input_key (str): The key of the input data.
- apply_chat_template (bool): Whether to apply the chat template.
- input_template (str): The template to format the data.
"""
super().__init__()
self.tokenizer = tokenizer
self.input_template = input_template

if apply_chat_template:
apply_chat_template = self.tokenizer.apply_chat_template

self.prompts = []
try:
rank = get_rank()
except ValueError: # not initialized yet, which is the case in unit test
rank = 0
for data in tqdm(dataset, desc="Preprocessing data", disable=not rank == 0):
prompt = self._preprocess_data(data, input_template, input_key, apply_chat_template)
self.prompts.append(prompt)

def __len__(self) -> int:
"""
Overview:
Get the length of the dataset.
Returns:
- length (int): The length of the dataset.
"""
return len(self.prompts)

def __getitem__(self, idx: int) -> str:
"""
Overview:
Get the item at the given index.
Args:
- idx (int): The index of the item to get.
Returns:
- item (str): The item at the given index.
"""
return self.prompts[idx]

def _preprocess_data(
self,
data: Dict[str, Any],
input_template: str = None,
input_key: str = "input",
apply_chat_template: Union[bool, Callable] = False,
) -> str:
"""
Overview:
Preprocess the data to get the formatted prompt.
Arguments:
- data (Dict[str, Any]): The data to preprocess.
- input_template (str): The template to format the data.
- input_key (str): The key of the input data.
- apply_chat_template (Union[bool, Callable]): The function to apply the chat template, \
usually is the `tokenizer.apply_chat_template`.
Returns:
- prompt (str): The formatted prompt.
"""
if apply_chat_template:
chat = data[input_key]
if isinstance(chat, str):
chat = [{"role": "user", "content": chat}]
assert isinstance(chat, list) and all(isinstance(t, dict) for t in chat), "chat must be a list of dict"
prompt = apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
else:
prompt = data[input_key]
if input_template:
prompt = input_template.format(prompt)
return prompt
99 changes: 99 additions & 0 deletions ding/utils/data/tests/test_rlhf_offline_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import pytest
from datasets import load_dataset, concatenate_datasets
from rl.data.offlinerl_dataset import OfflineRLDataset
from transformers import AutoTokenizer


@pytest.fixture
def dataset():
# Load a sample dataset
hf_dataset = load_dataset("MMInstruction/VL-RewardBench", split='test')
# split pair data into two separate datasets
hf_dataset_1 = hf_dataset.map(
lambda x: {
"prompt": x["query"],
"response": x["response"][0],
'human_ranking': x["human_ranking"][0]
}
)
hf_dataset_2 = hf_dataset.map(
lambda x: {
"prompt": x["query"],
"response": x["response"][1],
'human_ranking': x["human_ranking"][0]
}
)
# combine two datasets
hf_dataset = concatenate_datasets([hf_dataset_1, hf_dataset_2])
# shuffle the dataset
hf_dataset = hf_dataset.shuffle(seed=42)
return hf_dataset


@pytest.fixture
def tokenizer():
# Load a tokenizer
return AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B")


@pytest.mark.unittest
def test_offline_rl_dataset_initialization(dataset, tokenizer):
# Test the initialization of the OfflineRLDataset
offline_dataset = OfflineRLDataset(
dataset=dataset,
tokenizer=tokenizer,
max_length=1024,
input_key="query",
output_key="response",
label_key="human_ranking"
)
assert len(offline_dataset) == len(dataset)
offline_dataset = OfflineRLDataset(
dataset=dataset,
tokenizer=tokenizer,
max_length=256,
input_key="query",
output_key="response",
label_key="human_ranking"
)
# lower max_length will filter out some samples
assert len(offline_dataset) < len(dataset)


@pytest.mark.unittest
def test_offline_rl_dataset_item_retrieval(dataset, tokenizer):
# Test retrieving an item from the OfflineRLDataset
offline_dataset = OfflineRLDataset(
dataset=dataset,
tokenizer=tokenizer,
max_length=256,
input_key="query",
output_key="response",
label_key="human_ranking"
)
item = offline_dataset[0]
assert "prompt" in item
assert "response" in item
assert "label" in item
assert "prompt_ids_len" in item
print(item)


@pytest.mark.unittest
def test_offline_rl_dataset_collate_fn(dataset, tokenizer):
# Test the collate function of the OfflineRLDataset
offline_dataset = OfflineRLDataset(
dataset=dataset,
tokenizer=tokenizer,
max_length=256,
input_key="query",
output_key="response",
label_key="human_ranking"
)
B = 10
item_list = [offline_dataset[i] for i in range(B)]
input_ids, attention_mask, labels, prompt_ids_lens = offline_dataset.collate_fn(item_list)
assert input_ids.size(0) == len(item_list) * 2 # because of the unmatched y'| x
assert attention_mask.size(0) == len(item_list) * 2
assert labels.size(0) == len(item_list) * 2
assert len(prompt_ids_lens) == len(item_list) * 2
Loading

0 comments on commit 702f244

Please sign in to comment.