diff --git a/tools/tabpfn/main.py b/tools/tabpfn/main.py index 2b4034f2b8..240a8ddb82 100644 --- a/tools/tabpfn/main.py +++ b/tools/tabpfn/main.py @@ -61,7 +61,7 @@ def train_evaluate(args): te_labels = [] s_time = time.time() if args["selected_task"] == "Classification": - classifier = TabPFNClassifier(device="cpu") + classifier = TabPFNClassifier() classifier.fit(tr_features, tr_labels) y_eval = classifier.predict(te_features) pred_probas_test = classifier.predict_proba(te_features) @@ -81,7 +81,7 @@ def train_evaluate(args): "Precision", ) else: - regressor = TabPFNRegressor(device="cpu") + regressor = TabPFNRegressor() regressor.fit(tr_features, tr_labels) y_eval = regressor.predict(te_features) if len(te_labels) > 0: