From f55e14cce8ea305e72a1b97b1040e646b27642c6 Mon Sep 17 00:00:00 2001
From: Alex Garcia
Date: Sun, 17 Nov 2024 08:38:57 -0800
Subject: [PATCH] ann-filtering-benchmark directory

---
 .gitignore                 |   2 +
 sqlite-vec.c               |  21 ++--
 tests/afbd/.gitignore      |   1 +
 tests/afbd/.python-version |   1 +
 tests/afbd/Makefile        |   9 ++
 tests/afbd/README.md       |  12 +++
 tests/afbd/test-afbd.py    | 215 +++++++++++++++++++++++++++++++++++++
 7 files changed, 251 insertions(+), 10 deletions(-)
 create mode 100644 tests/afbd/.gitignore
 create mode 100644 tests/afbd/.python-version
 create mode 100644 tests/afbd/Makefile
 create mode 100644 tests/afbd/README.md
 create mode 100644 tests/afbd/test-afbd.py

diff --git a/.gitignore b/.gitignore
index ef7a661..ad7d0d0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -26,3 +26,5 @@
 sqlite-vec.h
 tmp/
 poetry.lock
+
+*.jsonl
diff --git a/sqlite-vec.c b/sqlite-vec.c
index 69ec6d2..2687d79 100644
--- a/sqlite-vec.c
+++ b/sqlite-vec.c
@@ -5972,51 +5972,52 @@ int vec0_set_metadata_filter_bitmap(
       }
       break;
     }
-    case VEC0_METADATA_OPERATOR_GT: {
+    case VEC0_METADATA_OPERATOR_NE: {
       for(int i = 0; i < size; i++) {
         u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
         int n = ((int*) view)[0];
         char * s = (char *) &view[4];
-        bitmap_set(b, i, strncmp(s, target, n) > 0);
+        bitmap_set(b, i, strncmp(s, target, n) != 0);
       }
       break;
     }
-    case VEC0_METADATA_OPERATOR_LE: {
+    case VEC0_METADATA_OPERATOR_GT: {
       for(int i = 0; i < size; i++) {
         u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
         int n = ((int*) view)[0];
         char * s = (char *) &view[4];
-        bitmap_set(b, i, strncmp(s, target, n) <= 0);
+        bitmap_set(b, i, strncmp(s, target, n) > 0);
       }
       break;
     }
-    case VEC0_METADATA_OPERATOR_LT: {
+    case VEC0_METADATA_OPERATOR_GE: {
       for(int i = 0; i < size; i++) {
         u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
         int n = ((int*) view)[0];
         char * s = (char *) &view[4];
-        bitmap_set(b, i, strncmp(s, target, n) < 0);
+        bitmap_set(b, i, strncmp(s, target, n) >= 0);
       }
       break;
     }
-    case VEC0_METADATA_OPERATOR_GE: {
+    case VEC0_METADATA_OPERATOR_LE: {
       for(int i = 0; i < size; i++) {
         u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
         int n = ((int*) view)[0];
         char * s = (char *) &view[4];
-        bitmap_set(b, i, strncmp(s, target, n) >= 0);
+        bitmap_set(b, i, strncmp(s, target, n) <= 0);
       }
       break;
     }
-    case VEC0_METADATA_OPERATOR_NE: {
+    case VEC0_METADATA_OPERATOR_LT: {
       for(int i = 0; i < size; i++) {
         u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
         int n = ((int*) view)[0];
         char * s = (char *) &view[4];
-        bitmap_set(b, i, strncmp(s, target, n) != 0);
+        bitmap_set(b, i, strncmp(s, target, n) < 0);
       }
       break;
     }
+    }
     break;
   }
diff --git a/tests/afbd/.gitignore b/tests/afbd/.gitignore
new file mode 100644
index 0000000..aa1ec1e
--- /dev/null
+++ b/tests/afbd/.gitignore
@@ -0,0 +1 @@
+*.tgz
diff --git a/tests/afbd/.python-version b/tests/afbd/.python-version
new file mode 100644
index 0000000..e4fba21
--- /dev/null
+++ b/tests/afbd/.python-version
@@ -0,0 +1 @@
+3.12
diff --git a/tests/afbd/Makefile b/tests/afbd/Makefile
new file mode 100644
index 0000000..083b429
--- /dev/null
+++ b/tests/afbd/Makefile
@@ -0,0 +1,9 @@
+random_ints_1m.tgz:
+	curl -o $@ https://storage.googleapis.com/ann-filtered-benchmark/datasets/random_ints_1m.tgz
+
+random_float_1m.tgz:
+	curl -o $@ https://storage.googleapis.com/ann-filtered-benchmark/datasets/random_float_1m.tgz
+
+random_keywords_1m.tgz:
+	curl -o $@ https://storage.googleapis.com/ann-filtered-benchmark/datasets/random_keywords_1m.tgz
+all: random_ints_1m.tgz random_float_1m.tgz random_keywords_1m.tgz
diff --git a/tests/afbd/README.md b/tests/afbd/README.md
new file mode 100644
index 0000000..be7b6e5
--- /dev/null
+++ b/tests/afbd/README.md
@@ -0,0 +1,12 @@
+
+# hnm
+
+```
+tar -xOzf hnm.tgz ./tests.jsonl > tests.jsonl
+solite q "select group_concat(distinct key) from lines_read('tests.jsonl'), json_each(line -> '$.conditions.and[0]')"
+```
+
+
+```
+> python test-afbd.py build hnm.tgz --metadata product_group_name,colour_group_name,index_group_name,perceived_colour_value_name,section_name,product_type_name,department_name,graphical_appearance_name,garment_group_name,perceived_colour_master_name
+```
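The `solite` one-liner in the README only lists which metadata keys appear in the first `and` condition of each test case. For readers without `solite`, a rough Python equivalent is sketched below; it is illustrative only (not part of the patch) and assumes `tests.jsonl` was extracted from `hnm.tgz` exactly as shown above, skipping entries whose conditions use `or`.

```python
# Illustrative sketch: list the distinct metadata keys referenced by the
# first AND-condition of each test case in an extracted tests.jsonl.
import json

keys = set()
with open("tests.jsonl") as f:
    for line in f:
        test = json.loads(line)
        conditions = test.get("conditions", {})
        if conditions.get("and"):
            keys.update(conditions["and"][0].keys())

print(sorted(keys))
```

These key names are where the README's `--metadata` column list comes from.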
diff --git a/tests/afbd/test-afbd.py b/tests/afbd/test-afbd.py
new file mode 100644
index 0000000..098af40
--- /dev/null
+++ b/tests/afbd/test-afbd.py
@@ -0,0 +1,215 @@
+import numpy as np
+from tqdm import tqdm
+from deepdiff import DeepDiff
+
+import tarfile
+import json
+from io import BytesIO
+import sqlite3
+from typing import List
+from struct import pack
+import time
+from pathlib import Path
+import argparse
+
+
+def serialize_float32(vector: List[float]) -> bytes:
+    """Serializes a list of floats into the "raw bytes" format sqlite-vec expects"""
+    return pack("%sf" % len(vector), *vector)
+
+
+def build_command(file_path, metadata_set=None):
+    if metadata_set:
+        metadata_set = set(metadata_set.split(","))
+
+    file_path = Path(file_path)
+    print(f"reading {file_path}...")
+    t0 = time.time()
+    with tarfile.open(file_path, "r:gz") as archive:
+        for file in archive:
+            if file.name == "./payloads.jsonl":
+                payloads = [
+                    json.loads(line)
+                    for line in archive.extractfile(file.name).readlines()
+                ]
+            if file.name == "./tests.jsonl":
+                tests = [
+                    json.loads(line)
+                    for line in archive.extractfile(file.name).readlines()
+                ]
+            if file.name == "./vectors.npy":
+                f = BytesIO()
+                f.write(archive.extractfile(file.name).read())
+                f.seek(0)
+                vectors = np.load(f)
+
+    assert payloads is not None
+    assert tests is not None
+    assert vectors is not None
+    dimensions = vectors.shape[1]
+    metadata_columns = sorted(list(payloads[0].keys()))
+
+    def col_type(v):
+        if isinstance(v, int):
+            return "integer"
+        if isinstance(v, float):
+            return "float"
+        if isinstance(v, str):
+            return "text"
+        raise Exception(f"Unknown column type: {v}")
+
+    metadata_columns_types = [col_type(payloads[0][col]) for col in metadata_columns]
+
+    print(time.time() - t0)
+    t0 = time.time()
+    print("seeding...")
+
+    db = sqlite3.connect(f"{file_path.stem}.db")
+    db.execute("PRAGMA page_size = 16384")
+    db.row_factory = sqlite3.Row
+    db.enable_load_extension(True)
+    db.load_extension("../../dist/vec0")
+    db.enable_load_extension(False)
+
+    with db:
+        db.execute("create table tests(data)")
+
+        for test in tests:
+            db.execute("insert into tests values (?)", [json.dumps(test)])
+
+    with db:
+        create_sql = f"create virtual table v using vec0(vector float[{dimensions}] distance_metric=cosine"
+        insert_sql = "insert into v(rowid, vector"
+        for name, type in zip(metadata_columns, metadata_columns_types):
+            if metadata_set:
+                if name in metadata_set:
+                    create_sql += f", {name} {type}"
+                else:
+                    create_sql += f", +{name} {type}"
+            else:
+                create_sql += f", {name} {type}"
+
+            insert_sql += f", {name}"
+        create_sql += ")"
+        insert_sql += ") values (" + ",".join("?" * (2 + len(metadata_columns))) + ")"
+        print(create_sql)
+        print(insert_sql)
+
+        db.execute(create_sql)
+
+        for idx, (payload, vector) in enumerate(
+            tqdm(zip(payloads, vectors), total=len(payloads))
+        ):
+            params = [idx, vector]
+            for c in metadata_columns:
+                params.append(payload[c])
+            db.execute(insert_sql, params)
+
+    print(time.time() - t0)
+
+
+def tests_command(file_path):
+    file_path = Path(file_path)
+    db = sqlite3.connect(f"{file_path.stem}.db")
+    db.execute("PRAGMA cache_size = -100000000")
+    db.row_factory = sqlite3.Row
+    db.enable_load_extension(True)
+    db.load_extension("../../dist/vec0")
+    db.enable_load_extension(False)
+
+    tests = [
+        json.loads(row["data"])
+        for row in db.execute("select data from tests limit 2000").fetchall()
+    ]
+
+    num_or_skips = 0
+    num_1off_errors = 0
+
+    t0 = time.time()
+    print("testing...")
+    for idx, test in enumerate(tqdm(tests)):
+        query = test["query"]
+        conditions = test["conditions"]
+        expected_closest_ids = test["closest_ids"]
+        expected_closest_scores = test["closest_scores"]
+        if "or" in conditions:
+            num_or_skips += 1
+            continue
+
+        sql = "select rowid, 1 - distance as similarity from v where vector match ? and k = ?"
+        params = [serialize_float32(query), len(expected_closest_ids)]
+
+        for condition in conditions["and"]:
+            assert len(condition.keys()) == 1
+            column = list(condition.keys())[0]
+            assert len(list(condition[column].keys())) == 1
+            condition_type = list(condition[column].keys())[0]
+            if condition_type == "match":
+                value = condition[column]["match"]["value"]
+                sql += f" and {column} = ?"
+                params.append(value)
+            elif condition_type == "range":
+                sql += f" and {column} between ? and ?"
+                params.append(condition[column]["range"]["gt"])
+                params.append(condition[column]["range"]["lt"])
+            else:
+                raise Exception(f"Unknown condition type: {condition_type}")
+
+        rows = db.execute(sql, params).fetchall()
+        actual_closest_ids = [row["rowid"] for row in rows]
+        matches = expected_closest_ids == actual_closest_ids
+        if not matches:
+            diff = DeepDiff(
+                expected_closest_ids, actual_closest_ids, ignore_order=False
+            )
+            assert len(list(diff.keys())) == 1
+            assert "values_changed" in diff.keys()
+            keys_changed = list(diff["values_changed"].keys())
+            if len(keys_changed) == 2:
+                akey, bkey = keys_changed
+                a = int(akey.lstrip("root[").rstrip("]"))
+                b = int(bkey.lstrip("root[").rstrip("]"))
+                assert abs(a - b) == 1
+                assert (
+                    diff["values_changed"][akey]["new_value"]
+                    == diff["values_changed"][bkey]["old_value"]
+                )
+                assert (
+                    diff["values_changed"][akey]["old_value"]
+                    == diff["values_changed"][bkey]["new_value"]
+                )
+            elif len(keys_changed) == 1:
+                v = int(keys_changed[0].lstrip("root[").rstrip("]"))
+                assert v == len(expected_closest_ids)
+            else:
+                raise Exception("unexpected DeepDiff result shape")
+            num_1off_errors += 1
+            # print(closest_scores)
+            # print([row["similarity"] for row in rows])
+            # assert closest_scores == [row["similarity"] for row in rows]
+    print("Number skipped: ", num_or_skips)
+    print("Num 1 off errors: ", num_1off_errors)
+    print("1 off error rate: ", num_1off_errors / (len(tests) - num_or_skips))
+    print(time.time() - t0)
+    print("done")
+
+
+def main():
+    parser = argparse.ArgumentParser(description="CLI tool")
+    subparsers = parser.add_subparsers(dest="command", required=True)
+
+    build_parser = subparsers.add_parser("build")
+    build_parser.add_argument("file", type=str, help="Path to input file")
+    build_parser.add_argument("--metadata", type=str, help="Metadata columns")
+    build_parser.set_defaults(func=lambda args: build_command(args.file, args.metadata))
+
+    tests_parser = subparsers.add_parser("test")
+    tests_parser.add_argument("file", type=str, help="Path to input file")
+    tests_parser.set_defaults(func=lambda args: tests_command(args.file))
+
+    args = parser.parse_args()
+    args.func(args)
+
+
+if __name__ == "__main__":
+    main()
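
Once `python test-afbd.py build <dataset>.tgz` has produced a database, individual filtered KNN queries can be inspected by hand. The sketch below mirrors the SQL that `tests_command` generates; the database name `hnm.db`, the 512-dimension placeholder vector, and the `product_group_name` value are assumptions for illustration, so substitute whatever the actual build produced.

```python
# Minimal sketch: one filtered KNN query against a database created by
# `python test-afbd.py build hnm.tgz --metadata product_group_name,...`.
# The dimension, example column value, and paths are placeholders.
import sqlite3
from struct import pack

db = sqlite3.connect("hnm.db")
db.enable_load_extension(True)
db.load_extension("../../dist/vec0")  # same loadable extension path the test script uses
db.enable_load_extension(False)

query_vector = [0.0] * 512  # a real query vector comes from a tests.jsonl entry
rows = db.execute(
    "select rowid, 1 - distance as similarity from v "
    "where vector match ? and k = ? and product_group_name = ?",
    [pack("%sf" % len(query_vector), *query_vector), 10, "Garment Upper body"],
).fetchall()
for row in rows:
    print(row)
```

The `1 - distance` expression matches the cosine `distance_metric` the build step configures, which is what the commented-out score comparison in `tests_command` relies on.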