Skip to content

Commit

Permalink
fix ray delete worker
Browse files Browse the repository at this point in the history
  • Loading branch information
delfosseaurelien committed Sep 17, 2021
1 parent f81fb3b commit 3d22948
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion biotransformers/wrappers/transformers_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,7 @@ def compute_embeddings(
_, embeddings = self._model_evaluation(
inputs, batch_size=batch_size, silent=silent
)
self.delete_ray_workers()
embeddings = [emb.cpu().numpy() for emb in embeddings]
# Remove class token and padding
# Use tranpose to filter on the two last dimensions. Doing this, we don't have to manage
Expand All @@ -747,7 +748,7 @@ def compute_embeddings(
embeddings_dict["mean"] = np.stack(
[e.transpose().mean(1).transpose() for e in filtered_embeddings]
)
self.delete_ray_workers()

return embeddings_dict

def compute_accuracy(
Expand Down

0 comments on commit 3d22948

Please sign in to comment.