Update train.py to support multiprocessing and torch.compile #282

Open · wants to merge 1 commit into base: old_gpt_2_chinese_before_2021_4_22
60 changes: 36 additions & 24 deletions train.py
@@ -11,32 +11,40 @@
from torch.nn import DataParallel
from tokenizations.bpe_tokenizer import get_encoder

from concurrent.futures import ProcessPoolExecutor as ppe
from multiprocessing import cpu_count

def process_subline(sublines, full_tokenizer, min_length, tokenized_data_path, idx):
    # slow part: per-line tokenization below (this is the work moved into worker processes)
    sublines = [full_tokenizer.tokenize(line) for line in sublines if
                len(line) > min_length]  # only keep lines longer than min_length
    sublines = [full_tokenizer.convert_tokens_to_ids(line) for line in sublines]
    full_line = []
    for subline in sublines:
        full_line.append(full_tokenizer.convert_tokens_to_ids('[MASK]'))  # prepend [MASK] to mark the start of an article
        full_line.extend(subline)
        full_line.append(full_tokenizer.convert_tokens_to_ids('[CLS]'))  # append [CLS] between articles to mark the end of an article
    with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(idx), 'w') as f:
        for token_id in full_line:
            f.write(str(token_id) + ' ')
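Each worker thus writes one piece as a flat, space-separated list of token ids. For illustration only, a hypothetical helper (not part of this PR or of train.py) that reads a piece back would look like:

# Hypothetical helper, for illustration only: reads the space-separated
# token ids that process_subline wrote for piece `idx` back into a list of ints.
def load_tokenized_piece(tokenized_data_path, idx):
    with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(idx), 'r') as f:
        return [int(token_id) for token_id in f.read().split()]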

def build_files(data_path, tokenized_data_path, num_pieces, full_tokenizer, min_length):
    with open(data_path, 'r', encoding='utf8') as f:
        print('reading lines')
        lines = json.load(f)
        lines = [line.replace('\n', ' [SEP] ') for line in lines]  # [SEP] marks line breaks / ends of paragraphs
        all_len = len(lines)
    if not os.path.exists(tokenized_data_path):
        os.mkdir(tokenized_data_path)
    for i in tqdm(range(num_pieces)):
        sublines = lines[all_len // num_pieces * i: all_len // num_pieces * (i + 1)]
        if i == num_pieces - 1:
            sublines.extend(lines[all_len // num_pieces * (i + 1):])  # append the leftover tail to the last piece
        sublines = [full_tokenizer.tokenize(line) for line in sublines if
                    len(line) > min_length]  # only keep lines longer than min_length
        sublines = [full_tokenizer.convert_tokens_to_ids(line) for line in sublines]
        full_line = []
        for subline in sublines:
            full_line.append(full_tokenizer.convert_tokens_to_ids('[MASK]'))  # prepend [MASK] to mark the start of an article
            full_line.extend(subline)
            full_line.append(full_tokenizer.convert_tokens_to_ids('[CLS]'))  # append [CLS] between articles to mark the end of an article
        with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(i), 'w') as f:
            for token_id in full_line:
                f.write(str(token_id) + ' ')
    print('finish')

    with open(data_path, 'r', encoding='utf8') as f:
        print('reading lines')
        lines = json.load(f)
        lines = [line.replace('\n', ' [SEP] ') for line in lines]  # [SEP] marks line breaks / ends of paragraphs
        all_len = len(lines)
    if not os.path.exists(tokenized_data_path):
        os.mkdir(tokenized_data_path)

    with ppe(max_workers=cpu_count()) as pool:
        for i in tqdm(range(num_pieces)):
            sublines = lines[all_len // num_pieces * i: all_len // num_pieces * (i + 1)]
            if i == num_pieces - 1:
                sublines.extend(lines[all_len // num_pieces * (i + 1):])  # append the leftover tail to the last piece
            pool.submit(process_subline, sublines, full_tokenizer, min_length, tokenized_data_path, i)
        pool.shutdown(wait=True)  # redundant but harmless: the with-block already calls shutdown(wait=True) on exit
    print('finish')
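One caveat with fire-and-forget submit calls: an exception raised inside process_subline is stored on the returned future and never surfaces, so a failing piece can be skipped silently. A minimal variant that keeps the futures and re-raises worker errors might look like this (a sketch, not part of this PR; tail handling for the last piece omitted for brevity):

from concurrent.futures import as_completed

with ppe(max_workers=cpu_count()) as pool:
    futures = [pool.submit(process_subline,
                           lines[all_len // num_pieces * i: all_len // num_pieces * (i + 1)],
                           full_tokenizer, min_length, tokenized_data_path, i)
               for i in range(num_pieces)]
    for future in tqdm(as_completed(futures), total=num_pieces):
        future.result()  # re-raises any exception from the worker process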

def main():
    parser = argparse.ArgumentParser()
@@ -67,6 +75,7 @@ def main():
    parser.add_argument('--bpe_token', action='store_true', help='subword')
    parser.add_argument('--encoder_json', default="tokenizations/encoder.json", type=str, help="encoder.json")
    parser.add_argument('--vocab_bpe', default="tokenizations/vocab.bpe", type=str, help="vocab.bpe")
    parser.add_argument('--compile', action='store_true', help='enable torch.compile')

    args = parser.parse_args()
    print('args:\n' + args.__repr__())
@@ -122,6 +131,9 @@ def main():
        model = transformers.modeling_gpt2.GPT2LMHeadModel(config=model_config)
    else:
        model = transformers.modeling_gpt2.GPT2LMHeadModel.from_pretrained(args.pretrained_model)
    if args.compile:
        torch.set_float32_matmul_precision('high')  # allow TF32 matmul kernels on recent GPUs
        model = torch.compile(model, mode="max-autotune")
    model.train()
    model.to(device)
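For reference, the compile path in isolation (a minimal sketch assuming PyTorch 2.x and a CUDA GPU; mode="max-autotune" spends extra warm-up time searching for faster kernels, and the model recompiles if input shapes change):

import torch

torch.set_float32_matmul_precision('high')  # allow TF32 matmuls on Ampere+ GPUs

model = torch.nn.Linear(1024, 1024).cuda()
compiled = torch.compile(model, mode="max-autotune")

x = torch.randn(8, 1024, device='cuda')
y = compiled(x)  # first call triggers compilation; later calls with the same shape reuse the compiled graph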
