diff --git a/ding/utils/data/rlhf_offline_dataset.py b/ding/utils/data/rlhf_offline_dataset.py
new file mode 100644
index 0000000000..56596b658e
--- /dev/null
+++ b/ding/utils/data/rlhf_offline_dataset.py
@@ -0,0 +1,186 @@
+from typing import Iterable, Dict, List, Union, Any, Callable
+from functools import partial
+from tqdm import tqdm
+from torch.utils.data import Dataset
+from torch.distributed import get_rank
+import torch
+import torch.nn.functional as F
+
+
+def zero_pad_sequences(sequences: List[torch.Tensor], side: str = "left", value: int = 0) -> torch.Tensor:
+    """
+    Overview:
+        Pad a list of token tensors to the same length on the given side, then stack them into a batch.
+    """
+    assert side in ("left", "right")
+    max_len = max(seq.size(-1) for seq in sequences)
+    padded_sequences = []
+    for seq in sequences:
+        pad_len = max_len - seq.size(-1)
+        padding = (pad_len, 0) if side == "left" else (0, pad_len)
+        padded_sequences.append(F.pad(seq, padding, value=value))
+    return torch.stack(padded_sequences, dim=0)
+
+
+class OfflineRLDataset(Dataset):
+    """
+    Overview:
+        PyTorch Dataset for OfflineRL LLM training like KTO and DPO.
+    """
+
+    def __init__(
+            self,
+            dataset: Iterable[Dict],
+            tokenizer,
+            max_length: int,
+            input_key: str = "input",
+            output_key: str = "output",
+            label_key: str = "label",
+            apply_chat_template: bool = False,
+            tokenizer_chat_template: str = None,
+            input_template: str = None,
+            num_processors: int = 8,
+            parallel_load: bool = True
+    ) -> None:
+        """
+        Overview:
+            Initialize the OfflineRLDataset.
+        Arguments:
+            - dataset (Iterable[Dict]): The dataset to preprocess, e.g. a Hugging Face dataset.
+            - tokenizer (transformers.PreTrainedTokenizer): The tokenizer used to preprocess the data.
+            - max_length (int): The maximum token length of prompt plus response.
+            - input_key (str): The key of the input (prompt) field.
+            - output_key (str): The key of the output (response) field.
+            - label_key (str): The key of the preference label field.
+            - apply_chat_template (bool): Whether to apply the tokenizer's chat template.
+            - tokenizer_chat_template (str): A custom chat template overriding the tokenizer's default one.
+            - input_template (str): The template used to format the prompt.
+            - num_processors (int): The number of processes used by ``datasets.map`` when ``parallel_load`` is True.
+            - parallel_load (bool): Whether to preprocess the dataset in parallel via ``datasets.map``.
+        """
+        super().__init__()
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+
+        if apply_chat_template:
+            apply_chat_template = self.tokenizer.apply_chat_template
+            if tokenizer_chat_template:
+                self.tokenizer.chat_template = tokenizer_chat_template
+
+        # preprocess the dataset in parallel with `datasets.map`
+        if parallel_load:
+            preprocess_data_fn = partial(
+                self._preprocess_data,
+                input_template=input_template,
+                input_key=input_key,
+                output_key=output_key,
+                label_key=label_key,
+                apply_chat_template=apply_chat_template
+            )
+            processed_dataset = dataset.map(
+                preprocess_data_fn, remove_columns=dataset.column_names, num_proc=num_processors
+            )
+            # the preprocess function may set the prompt to None, so filter out such samples
+            processed_dataset = processed_dataset.filter(lambda x: x["prompt"] is not None)
+
+            self.prompts = processed_dataset["prompt"]
+            self.responses = processed_dataset["response"]
+            self.labels = processed_dataset["label"]
+            self.prompt_ids_lens = processed_dataset["prompt_ids_len"]
+        else:
+            self.prompts = []
+            self.responses = []
+            self.labels = []
+            self.prompt_ids_lens = []
+            try:
+                rank = get_rank()
+            except ValueError:  # process group not initialized yet, which is the case in unit tests
+                rank = 0
+            for data in tqdm(dataset, desc="Preprocessing data", disable=rank != 0):
+                processed_data = self._preprocess_data(
+                    data,
+                    input_template=input_template,
+                    input_key=input_key,
+                    output_key=output_key,
+                    label_key=label_key,
+                    apply_chat_template=apply_chat_template
+                )
+                if processed_data["prompt"] is not None:
+                    self.prompts.append(processed_data["prompt"])
+                    self.responses.append(processed_data["response"])
+                    self.labels.append(processed_data["label"])
+                    self.prompt_ids_lens.append(processed_data["prompt_ids_len"])
+
+    def _preprocess_data(
+            self,
+            data: Dict[str, Any],
+            input_template: str = None,
+            input_key: str = "input",
+            output_key: str = "output",
+            label_key: str = "label",
+            apply_chat_template: Union[bool, Callable] = False,
+    ) -> Dict[str, Any]:
+        label = data[label_key]
+
+        if apply_chat_template:
+            if output_key:
+                prompt = apply_chat_template(data[input_key], tokenize=False, add_generation_prompt=True)
+                response = apply_chat_template(data[input_key] + data[output_key], tokenize=False)[len(prompt):]
+            else:
+                prompt = apply_chat_template(data[input_key][:-1], tokenize=False, add_generation_prompt=True)
+                response = apply_chat_template(data[input_key], tokenize=False)[len(prompt):]
+        else:
+            prompt = data[input_key]
+            response = data[output_key]
+            if input_template:
+                prompt = input_template.format(prompt)
+
+        prompt_token = self.tokenizer(
+            prompt,
+            max_length=self.max_length,
+            # pad to the batch max length in `collate_fn` rather than to the global max length here
+            padding=False,
+            truncation=True,
+            return_tensors="pt",
+            # special tokens for the prompt are added in `collate_fn`
+            add_special_tokens=False,
+        )
+        prompt_ids_len = prompt_token["attention_mask"].int().sum().item()
+
+        # filter out samples whose prompt is longer than max_length (reserving 2 tokens for the answer)
+        if prompt_ids_len >= self.max_length - 2:
+            prompt = None
+
+        return {"prompt": prompt, "response": response, "label": label, "prompt_ids_len": prompt_ids_len}
+
+    def __len__(self) -> int:
+        """
+        Overview:
+            Get the length of the dataset.
+        Returns:
+            - length (int): The length of the dataset.
+        """
+        return len(self.prompts)
+
+    def __getitem__(self, idx: int) -> Dict[str, Union[torch.Tensor, int]]:
+        """
+        Overview:
+            Get the item at the given index.
+        Returns:
+            - item (Dict[str, Union[torch.Tensor, int]]): The item at the given index.
+        """
+        return {
+            "prompt": self.prompts[idx],
+            "response": self.responses[idx],
+            "label": self.labels[idx],
+            "prompt_ids_len": self.prompt_ids_lens[idx]
+        }
+
+    def collate_fn(self, item_list: List[Dict[str, Union[torch.Tensor, int]]]):
+        """
+        Overview:
+            Collate a list of items into padded batch tensors. Besides the matched (prompt, response) pairs,
+            each batch also appends mismatched pairs (prompt_i, response_{i+1}) labeled -1, which are used to
+            estimate the KL divergence between the policy and the reference model (e.g. in KTO).
+        """
+
+        def tokenizer(prompt: str, response: str):
+            text = (prompt + response).rstrip("\n")
+            if not text.endswith(self.tokenizer.eos_token):
+                text += " " + self.tokenizer.eos_token
+            inputs = self.tokenizer(
+                text,
+                max_length=self.max_length,
+                padding=False,
+                truncation=True,
+                return_tensors="pt",
+                add_special_tokens=False,
+            )
+            # force the last token to be EOS so that truncated samples still terminate properly
+            inputs["input_ids"][0][-1] = self.tokenizer.eos_token_id
+            inputs["attention_mask"][0][-1] = True
+            return inputs["input_ids"], inputs["attention_mask"]
+
+        tot_ids, tot_masks, tot_labels, prompt_ids_lens = [], [], [], []
+        for item in item_list:
+            input_ids, attention_mask = tokenizer(item["prompt"], item["response"])
+            tot_ids.append(input_ids)
+            tot_masks.append(attention_mask)
+            tot_labels.append(item["label"])
+            prompt_ids_lens.append(item["prompt_ids_len"])
+
+        # add mismatched y' | x pairs (used to estimate the KL divergence between policy and reference)
+        for idx in range(len(item_list)):
+            next_idx = (idx + 1) % len(item_list)
+            input_ids, attention_mask = tokenizer(item_list[idx]["prompt"], item_list[next_idx]["response"])
+            tot_ids.append(input_ids)
+            tot_masks.append(attention_mask)
+            tot_labels.append(-1)
+            prompt_ids_lens.append(item_list[idx]["prompt_ids_len"])
+
+        input_ids = zero_pad_sequences(tot_ids, side="right", value=self.tokenizer.pad_token_id)
+        attention_mask = zero_pad_sequences(tot_masks, side="right")
+        return input_ids, attention_mask, torch.LongTensor(tot_labels), prompt_ids_lens
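For reviewers, a minimal usage sketch of `OfflineRLDataset` (illustrative only, not part of the patch). It assumes a Hugging Face dataset whose columns match the default `input`/`output`/`label` keys and a tokenizer with `eos_token` and `pad_token_id` set; the dataset repo name below is a placeholder. Note that `collate_fn` returns `2 * batch_size` rows: the matched pairs first, then the mismatched `(x, y')` pairs labeled `-1`.

```python
# Illustrative sketch only; "org/preference-dataset" is a placeholder repo name.
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from ding.utils.data.rlhf_offline_dataset import OfflineRLDataset

hf_dataset = load_dataset("org/preference-dataset", split="train")  # assumed columns: input / output / label
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B")

offline_dataset = OfflineRLDataset(dataset=hf_dataset, tokenizer=tokenizer, max_length=1024)
loader = DataLoader(offline_dataset, batch_size=8, shuffle=True, collate_fn=offline_dataset.collate_fn)

for input_ids, attention_mask, labels, prompt_ids_lens in loader:
    # 8 matched (x, y) rows followed by 8 mismatched (x, y') rows labeled -1 for KL estimation
    assert input_ids.size(0) == 16
    break
```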
diff --git a/ding/utils/data/rlhf_online_dataset.py b/ding/utils/data/rlhf_online_dataset.py
new file mode 100644
index 0000000000..b192fe455e
--- /dev/null
+++ b/ding/utils/data/rlhf_online_dataset.py
@@ -0,0 +1,96 @@
+from typing import Any, Dict, Union, Callable, Iterable
+from tqdm import tqdm
+from torch.utils.data import Dataset
+from torch.distributed import get_rank
+
+
+class OnlineRLDataset(Dataset):
+    """
+    Overview:
+        PyTorch Dataset for OnlineRL LLM training like PPO.
+    """
+
+    def __init__(
+            self,
+            dataset: Iterable[Dict],
+            tokenizer,
+            input_key: str = "input",
+            apply_chat_template: bool = False,
+            input_template: str = None,
+    ) -> None:
+        """
+        Overview:
+            Initialize the OnlineRLDataset.
+        Arguments:
+            - dataset (Iterable[Dict]): The dataset to preprocess, e.g. a Hugging Face dataset.
+            - tokenizer (transformers.PreTrainedTokenizer): The tokenizer used to preprocess the data.
+            - input_key (str): The key of the input data.
+            - apply_chat_template (bool): Whether to apply the chat template.
+            - input_template (str): The template used to format the data.
+        """
+        super().__init__()
+        self.tokenizer = tokenizer
+        self.input_template = input_template
+
+        if apply_chat_template:
+            apply_chat_template = self.tokenizer.apply_chat_template
+
+        self.prompts = []
+        try:
+            rank = get_rank()
+        except ValueError:  # process group not initialized yet, which is the case in unit tests
+            rank = 0
+        for data in tqdm(dataset, desc="Preprocessing data", disable=rank != 0):
+            prompt = self._preprocess_data(data, input_template, input_key, apply_chat_template)
+            self.prompts.append(prompt)
+
+    def __len__(self) -> int:
+        """
+        Overview:
+            Get the length of the dataset.
+        Returns:
+            - length (int): The length of the dataset.
+        """
+        return len(self.prompts)
+
+    def __getitem__(self, idx: int) -> str:
+        """
+        Overview:
+            Get the item at the given index.
+        Arguments:
+            - idx (int): The index of the item to get.
+        Returns:
+            - item (str): The item at the given index.
+        """
+        return self.prompts[idx]
+
+    def _preprocess_data(
+            self,
+            data: Dict[str, Any],
+            input_template: str = None,
+            input_key: str = "input",
+            apply_chat_template: Union[bool, Callable] = False,
+    ) -> str:
+        """
+        Overview:
+            Preprocess the data to get the formatted prompt.
+        Arguments:
+            - data (Dict[str, Any]): The data to preprocess.
+            - input_template (str): The template used to format the data.
+            - input_key (str): The key of the input data.
+            - apply_chat_template (Union[bool, Callable]): The function used to apply the chat template, \
+                usually ``tokenizer.apply_chat_template``.
+        Returns:
+            - prompt (str): The formatted prompt.
+        """
+        if apply_chat_template:
+            chat = data[input_key]
+            if isinstance(chat, str):
+                chat = [{"role": "user", "content": chat}]
+            assert isinstance(chat, list) and all(isinstance(t, dict) for t in chat), "chat must be a list of dict"
+            prompt = apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
+        else:
+            prompt = data[input_key]
+            if input_template:
+                prompt = input_template.format(prompt)
+        return prompt
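Similarly, a brief sketch of how `OnlineRLDataset` might be consumed (illustrative, not part of the patch). The miniF2F dataset and Qwen tokenizer are simply the ones the unit tests below happen to use; the class just yields chat-templated prompt strings.

```python
# Illustrative sketch only; dataset and model choices follow the unit tests.
from datasets import load_dataset
from transformers import AutoTokenizer

from ding.utils.data.rlhf_online_dataset import OnlineRLDataset

hf_dataset = load_dataset("cat-searcher/minif2f-lean4", split="validation")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B")

prompts = OnlineRLDataset(
    dataset=hf_dataset,
    tokenizer=tokenizer,
    input_key="formal_statement",
    apply_chat_template=True,  # plain strings are wrapped as a single user turn before templating
)
print(prompts[0])  # a formatted prompt string, ready to be sent to a generation / rollout worker
```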
diff --git a/ding/utils/data/tests/test_rlhf_offline_dataset.py b/ding/utils/data/tests/test_rlhf_offline_dataset.py
new file mode 100644
index 0000000000..dfc889f217
--- /dev/null
+++ b/ding/utils/data/tests/test_rlhf_offline_dataset.py
@@ -0,0 +1,99 @@
+import pytest
+from datasets import load_dataset, concatenate_datasets
+from ding.utils.data.rlhf_offline_dataset import OfflineRLDataset
+from transformers import AutoTokenizer
+
+
+@pytest.fixture
+def dataset():
+    # Load a sample dataset
+    hf_dataset = load_dataset("MMInstruction/VL-RewardBench", split='test')
+    # split the paired data into two separate datasets
+    hf_dataset_1 = hf_dataset.map(
+        lambda x: {
+            "prompt": x["query"],
+            "response": x["response"][0],
+            'human_ranking': x["human_ranking"][0]
+        }
+    )
+    hf_dataset_2 = hf_dataset.map(
+        lambda x: {
+            "prompt": x["query"],
+            "response": x["response"][1],
+            'human_ranking': x["human_ranking"][0]
+        }
+    )
+    # combine the two datasets
+    hf_dataset = concatenate_datasets([hf_dataset_1, hf_dataset_2])
+    # shuffle the dataset
+    hf_dataset = hf_dataset.shuffle(seed=42)
+    return hf_dataset
+
+
+@pytest.fixture
+def tokenizer():
+    # Load a tokenizer
+    return AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B")
+
+
+@pytest.mark.unittest
+def test_offline_rl_dataset_initialization(dataset, tokenizer):
+    # Test the initialization of the OfflineRLDataset
+    offline_dataset = OfflineRLDataset(
+        dataset=dataset,
+        tokenizer=tokenizer,
+        max_length=1024,
+        input_key="query",
+        output_key="response",
+        label_key="human_ranking"
+    )
+    assert len(offline_dataset) == len(dataset)
+    offline_dataset = OfflineRLDataset(
+        dataset=dataset,
+        tokenizer=tokenizer,
+        max_length=256,
+        input_key="query",
+        output_key="response",
+        label_key="human_ranking"
+    )
+    # a lower max_length will filter out some samples
+    assert len(offline_dataset) < len(dataset)
+
+
+@pytest.mark.unittest
+def test_offline_rl_dataset_item_retrieval(dataset, tokenizer):
+    # Test retrieving an item from the OfflineRLDataset
+    offline_dataset = OfflineRLDataset(
+        dataset=dataset,
+        tokenizer=tokenizer,
+        max_length=256,
+        input_key="query",
+        output_key="response",
+        label_key="human_ranking"
+    )
+    item = offline_dataset[0]
+    assert "prompt" in item
+    assert "response" in item
+    assert "label" in item
+    assert "prompt_ids_len" in item
+    print(item)
+
+
+@pytest.mark.unittest
+def test_offline_rl_dataset_collate_fn(dataset, tokenizer):
+    # Test the collate function of the OfflineRLDataset
+    offline_dataset = OfflineRLDataset(
+        dataset=dataset,
+        tokenizer=tokenizer,
+        max_length=256,
+        input_key="query",
+        output_key="response",
+        label_key="human_ranking"
+    )
+    B = 10
+    item_list = [offline_dataset[i] for i in range(B)]
+    input_ids, attention_mask, labels, prompt_ids_lens = offline_dataset.collate_fn(item_list)
+    assert input_ids.size(0) == len(item_list) * 2  # because of the mismatched y' | x pairs
+    assert attention_mask.size(0) == len(item_list) * 2
+    assert labels.size(0) == len(item_list) * 2
+    assert len(prompt_ids_lens) == len(item_list) * 2
+ """ + if apply_chat_template: + chat = data[input_key] + if isinstance(chat, str): + chat = [{"role": "user", "content": chat}] + assert isinstance(chat, list) and all(isinstance(t, dict) for t in chat), "chat must be a list of dict" + prompt = apply_chat_template(chat, tokenize=False, add_generation_prompt=True) + else: + prompt = data[input_key] + if input_template: + prompt = input_template.format(prompt) + return prompt diff --git a/ding/utils/data/tests/test_rlhf_offline_dataset.py b/ding/utils/data/tests/test_rlhf_offline_dataset.py new file mode 100644 index 0000000000..dfc889f217 --- /dev/null +++ b/ding/utils/data/tests/test_rlhf_offline_dataset.py @@ -0,0 +1,99 @@ +import pytest +from datasets import load_dataset, concatenate_datasets +from rl.data.offlinerl_dataset import OfflineRLDataset +from transformers import AutoTokenizer + + +@pytest.fixture +def dataset(): + # Load a sample dataset + hf_dataset = load_dataset("MMInstruction/VL-RewardBench", split='test') + # split pair data into two separate datasets + hf_dataset_1 = hf_dataset.map( + lambda x: { + "prompt": x["query"], + "response": x["response"][0], + 'human_ranking': x["human_ranking"][0] + } + ) + hf_dataset_2 = hf_dataset.map( + lambda x: { + "prompt": x["query"], + "response": x["response"][1], + 'human_ranking': x["human_ranking"][0] + } + ) + # combine two datasets + hf_dataset = concatenate_datasets([hf_dataset_1, hf_dataset_2]) + # shuffle the dataset + hf_dataset = hf_dataset.shuffle(seed=42) + return hf_dataset + + +@pytest.fixture +def tokenizer(): + # Load a tokenizer + return AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B") + + +@pytest.mark.unittest +def test_offline_rl_dataset_initialization(dataset, tokenizer): + # Test the initialization of the OfflineRLDataset + offline_dataset = OfflineRLDataset( + dataset=dataset, + tokenizer=tokenizer, + max_length=1024, + input_key="query", + output_key="response", + label_key="human_ranking" + ) + assert len(offline_dataset) == len(dataset) + offline_dataset = OfflineRLDataset( + dataset=dataset, + tokenizer=tokenizer, + max_length=256, + input_key="query", + output_key="response", + label_key="human_ranking" + ) + # lower max_length will filter out some samples + assert len(offline_dataset) < len(dataset) + + +@pytest.mark.unittest +def test_offline_rl_dataset_item_retrieval(dataset, tokenizer): + # Test retrieving an item from the OfflineRLDataset + offline_dataset = OfflineRLDataset( + dataset=dataset, + tokenizer=tokenizer, + max_length=256, + input_key="query", + output_key="response", + label_key="human_ranking" + ) + item = offline_dataset[0] + assert "prompt" in item + assert "response" in item + assert "label" in item + assert "prompt_ids_len" in item + print(item) + + +@pytest.mark.unittest +def test_offline_rl_dataset_collate_fn(dataset, tokenizer): + # Test the collate function of the OfflineRLDataset + offline_dataset = OfflineRLDataset( + dataset=dataset, + tokenizer=tokenizer, + max_length=256, + input_key="query", + output_key="response", + label_key="human_ranking" + ) + B = 10 + item_list = [offline_dataset[i] for i in range(B)] + input_ids, attention_mask, labels, prompt_ids_lens = offline_dataset.collate_fn(item_list) + assert input_ids.size(0) == len(item_list) * 2 # because of the unmatched y'| x + assert attention_mask.size(0) == len(item_list) * 2 + assert labels.size(0) == len(item_list) * 2 + assert len(prompt_ids_lens) == len(item_list) * 2 diff --git a/ding/utils/data/tests/test_rlhf_online_dataset.py 
diff --git a/setup.py b/setup.py
index 1bcf3f8fcd..f3d60222f1 100644
--- a/setup.py
+++ b/setup.py
@@ -79,6 +79,8 @@
         'redis',  # parallel
         'mpire>=2.3.5',  # parallel
         'einops',
+        'transformers',
+        'datasets',
     ],
     extras_require={
         'test': [