Commit f55e14c: ann-filtering-benchmark directory

asg017 committed Nov 17, 2024
1 parent 052ba4b commit f55e14c
Showing 7 changed files with 251 additions and 10 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -26,3 +26,5 @@ sqlite-vec.h
 tmp/
 
 poetry.lock
+
+*.jsonl
21 changes: 11 additions & 10 deletions sqlite-vec.c
@@ -5972,51 +5972,52 @@ int vec0_set_metadata_filter_bitmap(
       }
       break;
     }
-    case VEC0_METADATA_OPERATOR_GT: {
+    case VEC0_METADATA_OPERATOR_NE: {
       for(int i = 0; i < size; i++) {
         u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
         int n = ((int*) view)[0];
         char * s = (char *) &view[4];
-        bitmap_set(b, i, strncmp(s, target, n) > 0);
+        bitmap_set(b, i, strncmp(s, target, n) != 0);
       }
       break;
     }
-    case VEC0_METADATA_OPERATOR_LE: {
+    case VEC0_METADATA_OPERATOR_GT: {
       for(int i = 0; i < size; i++) {
         u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
         int n = ((int*) view)[0];
         char * s = (char *) &view[4];
-        bitmap_set(b, i, strncmp(s, target, n) <= 0);
+        bitmap_set(b, i, strncmp(s, target, n) > 0);
       }
       break;
     }
-    case VEC0_METADATA_OPERATOR_LT: {
+    case VEC0_METADATA_OPERATOR_GE: {
       for(int i = 0; i < size; i++) {
         u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
         int n = ((int*) view)[0];
         char * s = (char *) &view[4];
-        bitmap_set(b, i, strncmp(s, target, n) < 0);
+        bitmap_set(b, i, strncmp(s, target, n) >= 0);
      }
       break;
     }
-    case VEC0_METADATA_OPERATOR_GE: {
+    case VEC0_METADATA_OPERATOR_LE: {
       for(int i = 0; i < size; i++) {
         u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
         int n = ((int*) view)[0];
         char * s = (char *) &view[4];
-        bitmap_set(b, i, strncmp(s, target, n) >= 0);
+        bitmap_set(b, i, strncmp(s, target, n) <= 0);
       }
       break;
     }
-    case VEC0_METADATA_OPERATOR_NE: {
+    case VEC0_METADATA_OPERATOR_LT: {
       for(int i = 0; i < size; i++) {
         u8 * view = &((u8*) buffer)[i * VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH];
         int n = ((int*) view)[0];
         char * s = (char *) &view[4];
-        bitmap_set(b, i, strncmp(s, target, n) != 0);
+        bitmap_set(b, i, strncmp(s, target, n) < 0);
       }
       break;
     }
+
    }
    break;
  }
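The hunk above only reorders the `case` labels; each operator keeps its matching `strncmp` comparison. Every case scans a buffer of fixed-width text "views" (a 4-byte length prefix followed by the string bytes) and sets one bit per row. A rough Python sketch of that logic, where `VIEW_LEN` and `text_filter_bitmap` are illustrative names and not anything from sqlite-vec:

```
import struct
import operator

VIEW_LEN = 16  # stand-in for VEC0_METADATA_TEXT_VIEW_BUFFER_LENGTH

def text_filter_bitmap(buffer: bytes, size: int, target: bytes, op) -> list[bool]:
    bits = []
    for i in range(size):
        view = buffer[i * VIEW_LEN : (i + 1) * VIEW_LEN]
        (n,) = struct.unpack_from("i", view)  # 4-byte length prefix
        s = view[4 : 4 + n]                   # string bytes after the prefix
        # emulate strncmp(s, target, n): -1, 0, or 1 over the first n bytes
        cmp = (s > target[:n]) - (s < target[:n])
        bits.append(op(cmp, 0))
    return bits

# e.g. the VEC0_METADATA_OPERATOR_NE case corresponds to op=operator.ne:
# text_filter_bitmap(buf, size, b"abc", operator.ne)
```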
1 change: 1 addition & 0 deletions tests/afbd/.gitignore
@@ -0,0 +1 @@
*.tgz
1 change: 1 addition & 0 deletions tests/afbd/.python-version
@@ -0,0 +1 @@
3.12
9 changes: 9 additions & 0 deletions tests/afbd/Makefile
@@ -0,0 +1,9 @@
random_ints_1m.tgz:
	curl -o $@ https://storage.googleapis.com/ann-filtered-benchmark/datasets/random_ints_1m.tgz

random_float_1m.tgz:
	curl -o $@ https://storage.googleapis.com/ann-filtered-benchmark/datasets/random_float_1m.tgz

random_keywords_1m.tgz:
	curl -o $@ https://storage.googleapis.com/ann-filtered-benchmark/datasets/random_keywords_1m.tgz
all: random_ints_1m.tgz random_float_1m.tgz random_keywords_1m.tgz
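All three datasets can be fetched at once through the `all` target:

```
make all
```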
12 changes: 12 additions & 0 deletions tests/afbd/README.md
@@ -0,0 +1,12 @@

# hnm

```
tar -xOzf hnm.tgz ./tests.jsonl > tests.jsonl
solite q "select group_concat(distinct key) from lines_read('tests.jsonl'), json_each(line -> '$.conditions.and[0]')"
```


```
> python test-afbd.py build hnm.tgz --metadata product_group_name,colour_group_name,index_group_name,perceived_colour_value_name,section_name,product_type_name,department_name,graphical_appearance_name,garment_group_name,perceived_colour_master_name
```
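
After building, the stored tests can be replayed with the script's `test` subcommand (see `test-afbd.py` below), which opens the `hnm.db` file produced by `build`:

```
> python test-afbd.py test hnm.tgz
```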
215 changes: 215 additions & 0 deletions tests/afbd/test-afbd.py
@@ -0,0 +1,215 @@
import numpy as np
from tqdm import tqdm
from deepdiff import DeepDiff

import tarfile
import json
from io import BytesIO
import sqlite3
from typing import List
from struct import pack
import time
from pathlib import Path
import argparse


def serialize_float32(vector: List[float]) -> bytes:
    """Serializes a list of floats into the "raw bytes" format sqlite-vec expects"""
    return pack("%sf" % len(vector), *vector)
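# For example, serialize_float32([1.0, 2.0]) packs two native-order 32-bit
# floats back to back: b"\x00\x00\x80?\x00\x00\x00@" on a little-endian machine.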


def build_command(file_path, metadata_set=None):
    if metadata_set:
        metadata_set = set(metadata_set.split(","))

    file_path = Path(file_path)
    print(f"reading {file_path}...")
    t0 = time.time()
    # ensure the asserts below fail cleanly if the archive is missing a file
    payloads = tests = vectors = None
    with tarfile.open(file_path, "r:gz") as archive:
        for file in archive:
            if file.name == "./payloads.jsonl":
                payloads = [
                    json.loads(line)
                    for line in archive.extractfile(file.name).readlines()
                ]
            if file.name == "./tests.jsonl":
                tests = [
                    json.loads(line)
                    for line in archive.extractfile(file.name).readlines()
                ]
            if file.name == "./vectors.npy":
                f = BytesIO()
                f.write(archive.extractfile(file.name).read())
                f.seek(0)
                vectors = np.load(f)

    assert payloads is not None
    assert tests is not None
    assert vectors is not None
    dimensions = vectors.shape[1]
    metadata_columns = sorted(list(payloads[0].keys()))

    def col_type(v):
        if isinstance(v, int):
            return "integer"
        if isinstance(v, float):
            return "float"
        if isinstance(v, str):
            return "text"
        raise Exception(f"Unknown column type: {v}")

    metadata_columns_types = [col_type(payloads[0][col]) for col in metadata_columns]

    print(time.time() - t0)
    t0 = time.time()
    print("seeding...")

    db = sqlite3.connect(f"{file_path.stem}.db")
    db.execute("PRAGMA page_size = 16384")
    db.row_factory = sqlite3.Row
    db.enable_load_extension(True)
    db.load_extension("../../dist/vec0")
    db.enable_load_extension(False)

    with db:
        db.execute("create table tests(data)")

        for test in tests:
            db.execute("insert into tests values (?)", [json.dumps(test)])

    with db:
        create_sql = f"create virtual table v using vec0(vector float[{dimensions}] distance_metric=cosine"
        insert_sql = "insert into v(rowid, vector"
        for name, type in zip(metadata_columns, metadata_columns_types):
            # columns named in --metadata become filterable metadata columns;
            # the rest are declared as auxiliary (+-prefixed) columns
            if metadata_set:
                if name in metadata_set:
                    create_sql += f", {name} {type}"
                else:
                    create_sql += f", +{name} {type}"
            else:
                create_sql += f", {name} {type}"

            insert_sql += f", {name}"
        create_sql += ")"
        insert_sql += ") values (" + ",".join("?" * (2 + len(metadata_columns))) + ")"
        print(create_sql)
        print(insert_sql)
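        # e.g. with dimensions=128, a filterable column "a" (integer) and an
        # auxiliary column "b" (text), the printed statements would look like
        # (hypothetical names):
        #   create virtual table v using vec0(vector float[128] distance_metric=cosine, a integer, +b text)
        #   insert into v(rowid, vector, a, b) values (?,?,?,?)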

        db.execute(create_sql)

        for idx, (payload, vector) in enumerate(
            tqdm(zip(payloads, vectors), total=len(payloads))
        ):
            params = [idx, vector]
            for c in metadata_columns:
                params.append(payload[c])
            db.execute(insert_sql, params)

    print(time.time() - t0)


def tests_command(file_path):
    file_path = Path(file_path)
    db = sqlite3.connect(f"{file_path.stem}.db")
    db.execute("PRAGMA cache_size = -100000000")
    db.row_factory = sqlite3.Row
    db.enable_load_extension(True)
    db.load_extension("../../dist/vec0")
    db.enable_load_extension(False)

    tests = [
        json.loads(row["data"])
        for row in db.execute("select data from tests limit 2000").fetchall()
    ]

    num_or_skips = 0
    num_1off_errors = 0

    t0 = time.time()
    print("testing...")
    for idx, test in enumerate(tqdm(tests)):
        query = test["query"]
        conditions = test["conditions"]
        expected_closest_ids = test["closest_ids"]
        expected_closest_scores = test["closest_scores"]
        if "or" in conditions:
            num_or_skips += 1
            continue

        sql = "select rowid, 1 - distance as similarity from v where vector match ? and k = ?"
        params = [serialize_float32(query), len(expected_closest_ids)]

        # translate each benchmark condition into a SQL predicate on a metadata column
        for condition in conditions["and"]:
            assert len(condition.keys()) == 1
            column = list(condition.keys())[0]
            assert len(list(condition[column].keys())) == 1
            condition_type = list(condition[column].keys())[0]
            if condition_type == "match":
                value = condition[column]["match"]["value"]
                sql += f" and {column} = ?"
                params.append(value)
            elif condition_type == "range":
                sql += f" and {column} between ? and ?"
                params.append(condition[column]["range"]["gt"])
                params.append(condition[column]["range"]["lt"])
            else:
                raise Exception(f"Unknown condition type: {condition_type}")
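        # e.g. a test whose conditions are {"and": [{"a": {"match": {"value": 3}}}]}
        # (hypothetical) yields:
        #   select rowid, 1 - distance as similarity from v
        #   where vector match ? and k = ? and a = ?
        # note: SQL "between" is inclusive on both ends, while the benchmark's
        # range bounds ("gt"/"lt") are strict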

        rows = db.execute(sql, params).fetchall()
        actual_closest_ids = [row["rowid"] for row in rows]
        matches = expected_closest_ids == actual_closest_ids
        if not matches:
            diff = DeepDiff(
                expected_closest_ids, actual_closest_ids, ignore_order=False
            )
            assert len(list(diff.keys())) == 1
            assert "values_changed" in diff.keys()
            keys_changed = list(diff["values_changed"].keys())
            if len(keys_changed) == 2:
                # two adjacent positions swapped their values
                akey, bkey = keys_changed
                a = int(akey.lstrip("root[").rstrip("]"))
                b = int(bkey.lstrip("root[").rstrip("]"))
                assert abs(a - b) == 1
                assert (
                    diff["values_changed"][akey]["new_value"]
                    == diff["values_changed"][bkey]["old_value"]
                )
                assert (
                    diff["values_changed"][akey]["old_value"]
                    == diff["values_changed"][bkey]["new_value"]
                )
            elif len(keys_changed) == 1:
                v = int(keys_changed[0].lstrip("root[").rstrip("]"))
                assert v == len(expected_closest_ids)
            else:
                raise Exception(f"unexpected diff shape: {keys_changed}")
            num_1off_errors += 1
            # print(closest_scores)
            # print([row["similarity"] for row in rows])
            # assert closest_scores == [row["similarity"] for row in rows]
    print("Number skipped: ", num_or_skips)
    print("Num 1 off errors: ", num_1off_errors)
    print("1 off error rate: ", num_1off_errors / (len(tests) - num_or_skips))
    print(time.time() - t0)
    print("done")


def main():
    parser = argparse.ArgumentParser(description="CLI tool")
    subparsers = parser.add_subparsers(dest="command", required=True)

    build_parser = subparsers.add_parser("build")
    build_parser.add_argument("file", type=str, help="Path to input file")
    build_parser.add_argument("--metadata", type=str, help="Metadata columns")
    build_parser.set_defaults(func=lambda args: build_command(args.file, args.metadata))

    tests_parser = subparsers.add_parser("test")
    tests_parser.add_argument("file", type=str, help="Path to input file")
    tests_parser.set_defaults(func=lambda args: tests_command(args.file))

    args = parser.parse_args()
    args.func(args)


if __name__ == "__main__":
    main()
