diff --git a/README.md b/README.md index 769aed7e..3f350beb 100644 --- a/README.md +++ b/README.md @@ -81,6 +81,84 @@ numpy.allclose(mkl_res, np_res) # True ``` +--- +# Patching Mechanisms + +The `mkl_fft` provides convenient ways to enable MKL-accelerated FFT operations in NumPy with or without modifying your code. It supports both persistent patching (applies to all Python sessions) and one-shot execution (applies only to a single command). It also supports Python functions and context managers that do the same. + +## Persistent Patching + +### Install Persistent Patch + +```bash +python -m mkl_fft patch install +``` + +### Check Patch Status + +```bash +python -m mkl_fft patch status +``` + +Checks whether the persistent patch is currently installed. Returns exit code 0 if installed, 1 if not installed. + +### Uninstall Persistent Patch + +```bash +python -m mkl_fft patch uninstall +``` + +Removes the persistent patch file, restoring NumPy to its default FFT implementation. + +## One-Shot Execution + +```bash +python -m mkl_fft with_patch [args...] +``` + +Runs a single command with MKL-accelerated FFT enabled. The patch is only active for that specific execution and does not persist. + +**Examples:** + +```bash +# Run a Python script with MKL acceleration +python -m mkl_fft with_patch python my_script.py + +# Run tests with MKL acceleration +python -m mkl_fft with_patch python -m pytest tests/ + +# Run a Python one-liner +python -m mkl_fft with_patch python -c "import numpy; print(numpy.fft.fft.__module__)" + +# Run benchmarks with MKL acceleration +python -m mkl_fft with_patch python run_benchmarks.py +``` + +## Programmatic Usage + +You can also patch NumPy programmatically in your Python code: + +```python +import mkl_fft + +# Check if currently patched +if mkl_fft.is_patched(): + print("NumPy FFT is using MKL") + +# Enable patching globally +mkl_fft.patch_numpy_fft() + +# Disable patching +mkl_fft.restore_numpy_fft() + +# Use as context manager (recommended for temporary patching) +with mkl_fft.mkl_fft(): + # NumPy FFT uses MKL inside this block + import numpy as np + result = np.fft.fft(data) +# NumPy FFT restored outside the block +``` + --- # Building from source diff --git a/mkl_fft/__main__.py b/mkl_fft/__main__.py new file mode 100644 index 00000000..32c72247 --- /dev/null +++ b/mkl_fft/__main__.py @@ -0,0 +1,70 @@ +# Copyright (c) 2017, Intel Corporation +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of Intel Corporation nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Command-line interface for mkl_fft.""" + +import sys + + +def main_impl(): + if len(sys.argv) < 2: + print("Usage: python -m mkl_fft [args]") + print() + print("Commands:") + print(" patch install Install persistent NumPy FFT patch") + print(" patch uninstall Uninstall persistent NumPy FFT patch") + print(" patch status Check if persistent patch is installed") + print(" with_patch Run command with temporary NumPy FFT patch") + print() + print("Examples:") + print(" python -m mkl_fft patch install") + print(" python -m mkl_fft with_patch python script.py") + sys.exit(1) + + command = sys.argv[1] + + if command == "patch": + from mkl_fft.patch import main as patch_main + + patch_main(sys.argv[2:]) + elif command == "with_patch": + from mkl_fft.with_patch import main as with_patch_main + + with_patch_main(sys.argv[2:]) + else: + print(f"Unknown command: {command}") + sys.exit(1) + + +def main(): + """Entry point that avoids importing mkl_fft package.""" + try: + main_impl() + except Exception: + main_impl() + + +if __name__ == "__main__": + main_impl() diff --git a/mkl_fft/patch.py b/mkl_fft/patch.py new file mode 100644 index 00000000..3303806d --- /dev/null +++ b/mkl_fft/patch.py @@ -0,0 +1,149 @@ +# Copyright (c) 2017, Intel Corporation +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of Intel Corporation nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Persistent patch management for NumPy FFT submodule.""" + +import argparse +import os +import site +import sys +from pathlib import Path + + +def get_sitecustomize_path(): + """Get the path to sitecustomize.py in the user site-packages.""" + user_site = site.getusersitepackages() + if not user_site: + user_site = site.getsitepackages()[0] + return Path(user_site) / "sitecustomize.py" + + +def get_pth_path(): + """Get the path to mkl_fft.pth in the appropriate site-packages.""" + site_packages = site.getsitepackages() + if site_packages: + target_site = site_packages[0] + else: + target_site = site.getusersitepackages() + return Path(target_site) / "mkl_fft_patch.pth" + + +PATCH_CODE = """# mkl_fft persistent patch - auto-generated +try: + import mkl_fft + mkl_fft.patch_numpy_fft() +except Exception: + pass +""" + +PTH_CONTENT = """import mkl_fft; mkl_fft.patch_numpy_fft()""" + + +def install_patch(): + """Install persistent NumPy FFT patch using .pth file.""" + pth_path = get_pth_path() + + if pth_path.exists(): + print(f"Persistent patch already installed at {pth_path}") + return + + try: + pth_path.parent.mkdir(parents=True, exist_ok=True) + pth_path.write_text(PTH_CONTENT) + print(f"✓ Persistent patch installed at {pth_path}") + print() + print("NumPy FFT will now use MKL-accelerated implementations in all") + print("Python sessions. To disable, run:") + print(" python -m mkl_fft patch uninstall") + except (IOError, OSError) as e: + print(f"Error installing patch: {e}") + print() + print("You may need to run with appropriate permissions or install to") + print("a user site-packages directory.") + sys.exit(1) + + +def uninstall_patch(): + """Uninstall persistent NumPy FFT patch.""" + pth_path = get_pth_path() + + if not pth_path.exists(): + print("No persistent patch found.") + return + + try: + pth_path.unlink() + print(f"✓ Persistent patch removed from {pth_path}") + print() + print("NumPy FFT will now use the default implementations.") + except (IOError, OSError) as e: + print(f"Error removing patch: {e}") + sys.exit(1) + + +def check_status(): + """Check if persistent patch is installed.""" + pth_path = get_pth_path() + + if pth_path.exists(): + print(f"✓ Persistent patch is installed at {pth_path}") + print() + print("NumPy FFT is configured to use MKL-accelerated implementations.") + return True + else: + print("✗ No persistent patch installed") + print() + print("To enable MKL-accelerated NumPy FFT globally, run:") + print(" python -m mkl_fft patch install") + return False + + +def main(args=None): + """Main entry point for patch command.""" + parser = argparse.ArgumentParser( + prog="python -m mkl_fft patch", + description="Manage persistent NumPy FFT patching with MKL acceleration" + ) + subparsers = parser.add_subparsers(dest="command", help="Available commands") + + subparsers.add_parser("install", help="Install persistent NumPy FFT patch") + subparsers.add_parser("uninstall", help="Uninstall persistent NumPy FFT patch") + subparsers.add_parser("status", help="Check if persistent patch is installed") + + parsed_args = parser.parse_args(args) + + if not parsed_args.command: + parser.print_help() + sys.exit(1) + + if parsed_args.command == "install": + install_patch() + elif parsed_args.command == "uninstall": + uninstall_patch() + elif parsed_args.command == "status": + sys.exit(0 if check_status() else 1) + + +if __name__ == "__main__": + main() diff --git a/mkl_fft/tests/test_cli.py b/mkl_fft/tests/test_cli.py new file mode 100644 index 00000000..85314f49 --- /dev/null +++ b/mkl_fft/tests/test_cli.py @@ -0,0 +1,92 @@ +# Copyright (c) 2017, Intel Corporation +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of Intel Corporation nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import pytest + +from mkl_fft.patch import check_status, install_patch, uninstall_patch + + +@pytest.fixture +def mock_pth_path(tmp_path, monkeypatch): + """Mock the .pth file path to use a temporary directory.""" + pth_file = tmp_path / "mkl_fft_patch.pth" + + def mock_get_pth_path(): + return pth_file + + monkeypatch.setattr("mkl_fft.patch.get_pth_path", mock_get_pth_path) + return pth_file + + +def test_install_patch(mock_pth_path, capsys): + """Test installing persistent patch.""" + install_patch() + + assert mock_pth_path.exists() + content = mock_pth_path.read_text() + assert "mkl_fft.patch_numpy_fft()" in content + + captured = capsys.readouterr() + assert "Persistent patch installed" in captured.out + + +def test_install_patch_already_installed(mock_pth_path, capsys): + """Test installing patch when already installed.""" + install_patch() + install_patch() + + captured = capsys.readouterr() + assert "already installed" in captured.out + + +def test_uninstall_patch(mock_pth_path, capsys): + """Test uninstalling persistent patch.""" + install_patch() + assert mock_pth_path.exists() + + uninstall_patch() + assert not mock_pth_path.exists() + + captured = capsys.readouterr() + assert "Persistent patch removed" in captured.out + + +def test_uninstall_patch_not_installed(mock_pth_path, capsys): + """Test uninstalling patch when not installed.""" + uninstall_patch() + + captured = capsys.readouterr() + assert "No persistent patch found" in captured.out + + +def test_patch_status_check_function(mock_pth_path): + """Test check_status function return values.""" + assert not check_status() + + install_patch() + assert check_status() + + uninstall_patch() + assert not check_status() diff --git a/mkl_fft/with_patch.py b/mkl_fft/with_patch.py new file mode 100644 index 00000000..a56a94f7 --- /dev/null +++ b/mkl_fft/with_patch.py @@ -0,0 +1,92 @@ +# Copyright (c) 2017, Intel Corporation +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of Intel Corporation nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Run Python commands with temporary NumPy FFT patch.""" + +import argparse +import os +import subprocess +import sys + + +def main(args=None): + """Run a command with mkl_fft NumPy patch enabled.""" + parser = argparse.ArgumentParser( + prog="python -m mkl_fft with_patch", + description="Run a command with temporary MKL-accelerated NumPy FFT", + usage="python -m mkl_fft with_patch [args...]" + ) + parser.add_argument("command", nargs=argparse.REMAINDER, help="Command to execute with patch enabled") + + parsed_args = parser.parse_args(args) + + if not parsed_args.command: + parser.print_help() + print() + print("Examples:") + print(" python -m mkl_fft with_patch python script.py") + print(" python -m mkl_fft with_patch python -m pytest tests/") + print(" python -m mkl_fft with_patch python -c 'import numpy; print(numpy.fft.fft.__module__)'") + sys.exit(1) + + args = parsed_args.command + + patch_script = "import mkl_fft; mkl_fft.patch_numpy_fft()" + + env = os.environ.copy() + + if "PYTHONSTARTUP" in env: + existing_startup = env["PYTHONSTARTUP"] + print( + f"Warning: PYTHONSTARTUP is already set to {existing_startup}", + file=sys.stderr, + ) + print( + "The mkl_fft patch will be applied, but existing startup script will run first.", + file=sys.stderr, + ) + + import tempfile + + with tempfile.NamedTemporaryFile( + mode="w", suffix=".py", delete=False + ) as startup_file: + startup_file.write(patch_script) + startup_path = startup_file.name + + env["PYTHONSTARTUP"] = startup_path + + try: + result = subprocess.run(args, env=env) + sys.exit(result.returncode) + finally: + try: + os.unlink(startup_path) + except OSError: + pass + + +if __name__ == "__main__": + main()