This repository has been archived by the owner on Feb 24, 2020. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathbaseline_tfidf.py
executable file
·64 lines (45 loc) · 1.81 KB
/
baseline_tfidf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#!/usr/bin/env python3
import os
import warnings
import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_distances
def read_explanations(path):
    """Load one tab-separated table and return its rows as (uid, text) pairs.

    Columns whose name starts with ``[SKIP]`` are left out of the text. The
    first such column whose name contains ``UID`` supplies the identifier.
    The text of a row is its non-missing cells from the remaining columns,
    in column order, joined with single spaces.

    Args:
        path: Path of the tab-separated file to read.

    Returns:
        A list of ``(uid, text)`` tuples, one per data row. An empty list
        (after emitting a warning) when the file has no content at all, has
        no UID column, or has no data rows.
    """
    try:
        df = pd.read_csv(path, sep='\t', dtype=str)
    except pd.errors.EmptyDataError:
        # A zero-byte file has no header to parse; treat it like the other
        # unusable inputs below instead of aborting the caller's whole walk.
        warnings.warn('Possibly misformatted file: ' + path)
        return []

    uid = None
    header = []
    for name in df.columns:
        if not name.startswith('[SKIP]'):
            header.append(name)
        elif uid is None and 'UID' in name:
            # Only the first matching [SKIP] column is used as the identifier.
            uid = name

    if uid is None or df.empty:
        warnings.warn('Possibly misformatted file: ' + path)
        return []

    explanations = []
    for _, row in df.iterrows():
        # Missing cells are NaN and are dropped so they do not end up in the
        # text as the literal string "nan".
        text = ' '.join(str(cell) for cell in row[header] if not pd.isna(cell))
        explanations.append((row[uid], text))
    return explanations
def main():
    """Rank explanations for each question by TF-IDF cosine distance.

    Command line: ``[-n NEAREST] tables questions`` where ``tables`` is a
    directory walked recursively for tab-separated explanation tables and
    ``questions`` is a tab-separated file with ``Question`` and
    ``questionID`` columns.

    For every question, prints up to NEAREST (default 10) lines of the form
    ``questionID<TAB>uid`` to standard output, nearest explanation first.
    """
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-n', '--nearest', type=int, default=10)
    parser.add_argument('tables')
    parser.add_argument('questions', type=argparse.FileType('r', encoding='UTF-8'))
    args = parser.parse_args()

    explanations = []
    for path, _, files in os.walk(args.tables):
        for filename in files:
            explanations += read_explanations(os.path.join(path, filename))

    if not explanations:
        warnings.warn('Empty explanations')

    # argparse opened the questions file for us; close it once it is parsed.
    with args.questions as questions_file:
        df_q = pd.read_csv(questions_file, sep='\t', dtype=str)
    df_e = pd.DataFrame(explanations, columns=('uid', 'text'))

    # The vocabulary and IDF weights come from the explanation texts only.
    # The original chained .fit(questions).fit(explanations), but a second
    # fit() discards the first, so the questions never contributed.
    # NOTE(review): fitting on both corpora would change the ranking --
    # confirm the intended baseline before doing so.
    vectorizer = TfidfVectorizer().fit(df_e['text'])
    X_q = vectorizer.transform(df_q['Question'])
    X_e = vectorizer.transform(df_e['text'])
    X_dist = cosine_distances(X_q, X_e)

    # Pull the identifier columns out once instead of doing a label lookup
    # on the data frames for every printed line.
    question_ids = df_q['questionID'].tolist()
    explanation_uids = df_e['uid'].tolist()

    for question_id, distances in zip(question_ids, X_dist):
        for i_explanation in np.argsort(distances)[:args.nearest]:
            print('{}\t{}'.format(question_id, explanation_uids[i_explanation]))
# Run the baseline only when executed as a script, not when imported.
if __name__ == '__main__':
    main()