Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,84 @@ numpy.allclose(mkl_res, np_res)
# True
```

---
# Patching Mechanisms

The `mkl_fft` provides convenient ways to enable MKL-accelerated FFT operations in NumPy with or without modifying your code. It supports both persistent patching (applies to all Python sessions) and one-shot execution (applies only to a single command). It also supports Python functions and context managers that do the same.
Comment thread
jharlow-intel marked this conversation as resolved.

## Persistent Patching

### Install Persistent Patch

```bash
python -m mkl_fft patch install
```

### Check Patch Status

```bash
python -m mkl_fft patch status
```

Checks whether the persistent patch is currently installed. Returns exit code 0 if installed, 1 if not installed.

### Uninstall Persistent Patch

```bash
python -m mkl_fft patch uninstall
```

Removes the persistent patch file, restoring NumPy to its default FFT implementation.

## One-Shot Execution

```bash
python -m mkl_fft with_patch <command> [args...]
```

Runs a single command with MKL-accelerated FFT enabled. The patch is only active for that specific execution and does not persist.
Comment thread
jharlow-intel marked this conversation as resolved.

**Examples:**

```bash
# Run a Python script with MKL acceleration
python -m mkl_fft with_patch python my_script.py

# Run tests with MKL acceleration
python -m mkl_fft with_patch python -m pytest tests/

# Run a Python one-liner
python -m mkl_fft with_patch python -c "import numpy; print(numpy.fft.fft.__module__)"

# Run benchmarks with MKL acceleration
python -m mkl_fft with_patch python run_benchmarks.py
```

## Programmatic Usage

You can also patch NumPy programmatically in your Python code:

```python
import mkl_fft

# Check if currently patched
if mkl_fft.is_patched():
print("NumPy FFT is using MKL")

# Enable patching globally
mkl_fft.patch_numpy_fft()

# Disable patching
mkl_fft.restore_numpy_fft()

# Use as context manager (recommended for temporary patching)
with mkl_fft.mkl_fft():
# NumPy FFT uses MKL inside this block
import numpy as np
result = np.fft.fft(data)
# NumPy FFT restored outside the block
```

---
# Building from source

Expand Down
70 changes: 70 additions & 0 deletions mkl_fft/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright (c) 2017, Intel Corporation
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of Intel Corporation nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""Command-line interface for mkl_fft."""

import sys


def main_impl():
if len(sys.argv) < 2:
print("Usage: python -m mkl_fft <command> [args]")
print()
print("Commands:")
Comment thread
jharlow-intel marked this conversation as resolved.
print(" patch install Install persistent NumPy FFT patch")
print(" patch uninstall Uninstall persistent NumPy FFT patch")
print(" patch status Check if persistent patch is installed")
print(" with_patch <cmd> Run command with temporary NumPy FFT patch")
print()
print("Examples:")
print(" python -m mkl_fft patch install")
print(" python -m mkl_fft with_patch python script.py")
sys.exit(1)

command = sys.argv[1]

if command == "patch":
from mkl_fft.patch import main as patch_main

patch_main(sys.argv[2:])
elif command == "with_patch":
from mkl_fft.with_patch import main as with_patch_main

with_patch_main(sys.argv[2:])
else:
print(f"Unknown command: {command}")
sys.exit(1)


def main():
"""Entry point that avoids importing mkl_fft package."""
try:
main_impl()
except Exception:
main_impl()


if __name__ == "__main__":
main_impl()
Comment thread
jharlow-intel marked this conversation as resolved.
149 changes: 149 additions & 0 deletions mkl_fft/patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright (c) 2017, Intel Corporation
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of Intel Corporation nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""Persistent patch management for NumPy FFT submodule."""

import argparse
import os
Comment thread
jharlow-intel marked this conversation as resolved.
import site
import sys
from pathlib import Path


def get_sitecustomize_path():
"""Get the path to sitecustomize.py in the user site-packages."""
user_site = site.getusersitepackages()
if not user_site:
user_site = site.getsitepackages()[0]
return Path(user_site) / "sitecustomize.py"
Comment thread
jharlow-intel marked this conversation as resolved.


def get_pth_path():
"""Get the path to mkl_fft.pth in the appropriate site-packages."""
Comment thread
jharlow-intel marked this conversation as resolved.
site_packages = site.getsitepackages()
if site_packages:
target_site = site_packages[0]
else:
target_site = site.getusersitepackages()
return Path(target_site) / "mkl_fft_patch.pth"


PATCH_CODE = """# mkl_fft persistent patch - auto-generated
try:
import mkl_fft
mkl_fft.patch_numpy_fft()
except Exception:
pass
"""

PTH_CONTENT = """import mkl_fft; mkl_fft.patch_numpy_fft()"""
Comment thread
jharlow-intel marked this conversation as resolved.
Comment thread
jharlow-intel marked this conversation as resolved.


def install_patch():
"""Install persistent NumPy FFT patch using .pth file."""
pth_path = get_pth_path()

if pth_path.exists():
print(f"Persistent patch already installed at {pth_path}")
return

try:
pth_path.parent.mkdir(parents=True, exist_ok=True)
pth_path.write_text(PTH_CONTENT)
print(f"✓ Persistent patch installed at {pth_path}")
print()
print("NumPy FFT will now use MKL-accelerated implementations in all")
print("Python sessions. To disable, run:")
print(" python -m mkl_fft patch uninstall")
except (IOError, OSError) as e:
print(f"Error installing patch: {e}")
print()
print("You may need to run with appropriate permissions or install to")
print("a user site-packages directory.")
sys.exit(1)


def uninstall_patch():
"""Uninstall persistent NumPy FFT patch."""
pth_path = get_pth_path()

if not pth_path.exists():
print("No persistent patch found.")
return

try:
pth_path.unlink()
print(f"✓ Persistent patch removed from {pth_path}")
print()
print("NumPy FFT will now use the default implementations.")
except (IOError, OSError) as e:
print(f"Error removing patch: {e}")
sys.exit(1)


def check_status():
"""Check if persistent patch is installed."""
pth_path = get_pth_path()

if pth_path.exists():
print(f"✓ Persistent patch is installed at {pth_path}")
print()
print("NumPy FFT is configured to use MKL-accelerated implementations.")
return True
else:
print("✗ No persistent patch installed")
print()
print("To enable MKL-accelerated NumPy FFT globally, run:")
print(" python -m mkl_fft patch install")
return False


def main(args=None):
"""Main entry point for patch command."""
parser = argparse.ArgumentParser(
prog="python -m mkl_fft patch",
description="Manage persistent NumPy FFT patching with MKL acceleration"
)
subparsers = parser.add_subparsers(dest="command", help="Available commands")

subparsers.add_parser("install", help="Install persistent NumPy FFT patch")
subparsers.add_parser("uninstall", help="Uninstall persistent NumPy FFT patch")
subparsers.add_parser("status", help="Check if persistent patch is installed")

parsed_args = parser.parse_args(args)

if not parsed_args.command:
parser.print_help()
sys.exit(1)

if parsed_args.command == "install":
install_patch()
elif parsed_args.command == "uninstall":
uninstall_patch()
elif parsed_args.command == "status":
sys.exit(0 if check_status() else 1)


if __name__ == "__main__":
main()
92 changes: 92 additions & 0 deletions mkl_fft/tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright (c) 2017, Intel Corporation
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of Intel Corporation nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import pytest

from mkl_fft.patch import check_status, install_patch, uninstall_patch


@pytest.fixture
def mock_pth_path(tmp_path, monkeypatch):
"""Mock the .pth file path to use a temporary directory."""
pth_file = tmp_path / "mkl_fft_patch.pth"

def mock_get_pth_path():
return pth_file

monkeypatch.setattr("mkl_fft.patch.get_pth_path", mock_get_pth_path)
return pth_file


def test_install_patch(mock_pth_path, capsys):
"""Test installing persistent patch."""
install_patch()

assert mock_pth_path.exists()
content = mock_pth_path.read_text()
assert "mkl_fft.patch_numpy_fft()" in content

captured = capsys.readouterr()
assert "Persistent patch installed" in captured.out


def test_install_patch_already_installed(mock_pth_path, capsys):
"""Test installing patch when already installed."""
install_patch()
install_patch()

captured = capsys.readouterr()
assert "already installed" in captured.out


def test_uninstall_patch(mock_pth_path, capsys):
"""Test uninstalling persistent patch."""
install_patch()
assert mock_pth_path.exists()

uninstall_patch()
assert not mock_pth_path.exists()

captured = capsys.readouterr()
assert "Persistent patch removed" in captured.out


def test_uninstall_patch_not_installed(mock_pth_path, capsys):
"""Test uninstalling patch when not installed."""
uninstall_patch()

captured = capsys.readouterr()
assert "No persistent patch found" in captured.out


def test_patch_status_check_function(mock_pth_path):
"""Test check_status function return values."""
assert not check_status()

install_patch()
assert check_status()

uninstall_patch()
assert not check_status()
Loading
Loading