|
| 1 | +""" |
| 2 | +Simple test to verify OAuth header mutation fix |
| 3 | +""" |
| 4 | +import asyncio |
| 5 | +from okta.http_client import HTTPClient |
| 6 | + |
| 7 | + |
| 8 | +async def test_header_mutation(): |
| 9 | + """Test that sending form data doesn't mutate shared headers""" |
| 10 | + |
| 11 | + # Initialize HTTPClient with minimal config |
| 12 | + http_config = { |
| 13 | + "headers": { |
| 14 | + "User-Agent": "test-client", |
| 15 | + "Accept": "application/json" |
| 16 | + } |
| 17 | + } |
| 18 | + http_client = HTTPClient(http_config) |
| 19 | + |
| 20 | + # Get initial default headers |
| 21 | + initial_headers = dict(http_client._default_headers) |
| 22 | + print(f"Initial headers: {initial_headers}") |
| 23 | + |
| 24 | + # Simulate an OAuth request with form data |
| 25 | + oauth_request = { |
| 26 | + "method": "POST", |
| 27 | + "url": "https://test.okta.com/oauth2/v1/token", |
| 28 | + "headers": { |
| 29 | + "Accept": "application/json", |
| 30 | + "Content-Type": "application/x-www-form-urlencoded" |
| 31 | + }, |
| 32 | + "data": None, |
| 33 | + "form": { |
| 34 | + "grant_type": "client_credentials", |
| 35 | + "client_assertion": "test_jwt_token" |
| 36 | + } |
| 37 | + } |
| 38 | + |
| 39 | + # This should NOT mutate _default_headers |
| 40 | + try: |
| 41 | + # We'll get an error since we're not actually making a request, |
| 42 | + # but we just want to check header mutation doesn't happen |
| 43 | + # in the preparation phase |
| 44 | + result = await http_client.send_request(oauth_request) |
| 45 | + except Exception as e: |
| 46 | + # Expected to fail, we're just testing header mutation |
| 47 | + pass |
| 48 | + |
| 49 | + # Check headers after the request |
| 50 | + after_headers = dict(http_client._default_headers) |
| 51 | + print(f"After headers: {after_headers}") |
| 52 | + |
| 53 | + # Verify headers weren't mutated |
| 54 | + if initial_headers == after_headers: |
| 55 | + print("✅ SUCCESS: Headers were not mutated!") |
| 56 | + print(" Shared state is preserved correctly.") |
| 57 | + return True |
| 58 | + else: |
| 59 | + print("❌ FAILURE: Headers were mutated!") |
| 60 | + print(f" Initial: {initial_headers}") |
| 61 | + print(f" After: {after_headers}") |
| 62 | + added = set(after_headers.keys()) - set(initial_headers.keys()) |
| 63 | + removed = set(initial_headers.keys()) - set(after_headers.keys()) |
| 64 | + if added: |
| 65 | + print(f" Added keys: {added}") |
| 66 | + if removed: |
| 67 | + print(f" Removed keys: {removed}") |
| 68 | + return False |
| 69 | + |
| 70 | + |
| 71 | +if __name__ == '__main__': |
| 72 | + print("Testing OAuth header mutation fix...") |
| 73 | + print("=" * 60) |
| 74 | + result = asyncio.run(test_header_mutation()) |
| 75 | + print("=" * 60) |
| 76 | + if result: |
| 77 | + print("All tests passed! ✅") |
| 78 | + else: |
| 79 | + print("Tests failed! ❌") |
| 80 | + exit(1) |
| 81 | + |
| 82 | + |
0 commit comments