Implements the classical DTW algorithm in PyTorch, enabling multi-pattern matching and GPU acceleration

MicahSee/PyTorchDTW

PyTorchDTW

Implements the classical DTW algorithm in PyTorch, enabling multi-pattern matching and GPU acceleration.

To run the DTW module on a GPU, falling back to the CPU when CUDA is unavailable, use the code below:

import torch
from pytorch_dtw import DTW

# Use the GPU when one is available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dtw = DTW(device=device).to(device)

The DTW module currently supports matching one sequence against many patterns. The inputs to the forward() method must have shapes (1, n) and (k, m), where n is the length of the sequence being matched, k is the number of patterns to match against, and m is the length of each pattern. n and m can be different, but all k patterns share the same length m.

For example:

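import numpy as np

key = np.random.rand(1, 50)        # one sequence of length n = 50
patterns = np.random.rand(10, 75)  # k = 10 patterns, each of length m = 75

key_tensor = torch.tensor(key).to(device)
patterns_tensor = torch.tensor(patterns).to(device)

# Pass the tensors (not the NumPy arrays) to the module
costs = dtw(key_tensor, patterns_tensor)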
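
To sanity-check results on small inputs, it can help to have a slow, plain-Python version of the classical DTW recurrence to compare against. The sketch below is not this module's implementation: the helper names dtw_cost and dtw_costs are hypothetical, and the absolute difference used as the local cost is an assumption, so the module's values may differ if it uses another distance or a normalization step.

```python
import math


def dtw_cost(key, pattern):
    """Classical DTW cost between two 1-D sequences (lists of floats).

    Assumption: the local cost is the absolute difference |a - b|.
    """
    n, m = len(key), len(pattern)
    # acc[i][j] = minimal accumulated cost of aligning key[:i] with pattern[:j]
    acc = [[math.inf] * (m + 1) for _ in range(n + 1)]
    acc[0][0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            local = abs(key[i - 1] - pattern[j - 1])
            # Extend the cheapest of the three allowed predecessor alignments
            acc[i][j] = local + min(acc[i - 1][j], acc[i][j - 1], acc[i - 1][j - 1])
    return acc[n][m]


def dtw_costs(key, patterns):
    """Match one sequence of shape (1, n) against k patterns of shape (k, m).

    Returns a list of k costs, one per pattern.
    """
    if len(key) != 1:
        raise ValueError("key must have shape (1, n)")
    return [dtw_cost(key[0], pattern) for pattern in patterns]
```

The nested lists mirror the (1, n) and (k, m) layout described above, so tensors or arrays can be converted with .tolist() before calling dtw_costs. The double loop takes O(n * m) time per pattern, which is fine for a spot check but far slower than a batched GPU computation.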
