Skip to content

Commit 34ee45f

Browse files
committed
feat: update _cuda_recurrence to use is_floating_point and refactor parallel reduction import
1 parent 8c19673 commit 34ee45f

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

torchlpc/recurrence.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,11 @@ def _cuda_recurrence(
4646
impulse: torch.Tensor, decay: torch.Tensor, initial_state: torch.Tensor
4747
) -> torch.Tensor:
4848
n_dims, n_steps = decay.shape
49-
if impulse.dtype is torch.float32:
49+
if impulse.is_floating_point():
5050
try:
51-
from pararnn.parallel_reduction.parallel_reduction import ParallelSolve
51+
import pararnn.parallel_reduction.parallel_reduction
5252

53-
return ParallelSolve.parallel_reduce_diag_cuda(
53+
return torch.ops.parallel_reduce_cuda.parallel_reduce_diag_cuda(
5454
F.pad(-decay, (1, 0)),
5555
torch.cat([initial_state.unsqueeze(1), impulse], dim=1),
5656
)[:, 1:]

0 commit comments

Comments
 (0)