Skip to content

Commit e9f57de

Browse files
committed
CI: Added tests for ml orchestration scripts
1 parent aa6f3c3 commit e9f57de

4 files changed

Lines changed: 232 additions & 5 deletions

File tree

.github/workflows/test.yml

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,5 @@ jobs:
4242

4343
- name: Run tests
4444
run: |
45-
pytest tests/ -v --cov=src --cov-report=xml --cov-report=term
45+
pytest tests/ -v --cov=src --cov=ml --cov-report=xml --cov-report=term
4646
shell: bash -l {0}
47-
48-
49-
50-

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ addopts = """
4545
--tb=short
4646
--doctest-modules
4747
--cov=src/quantlab
48+
--cov=ml
4849
--cov-report=term-missing
4950
--cov-report=html
5051
"""
@@ -55,6 +56,11 @@ markers = [
5556
"slow: Tests that take >1s (skip by default)",
5657
]
5758

59+
[tool.coverage.run]
60+
source = ["src", "ml"]
61+
omit = ["tests/*", "*/__pycache__/*", "*.pyc"]
62+
disable_warnings = ["no-data-collected", "module-not-imported"]
63+
5864
[tool.black]
5965
line-length = 88
6066
target-version = ['py38']

tests/ml/test_metrics/test_pnl.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
"""Tests for P&L computation functions."""
2+
3+
import pytest
4+
import torch
5+
6+
from quantlab.ml.metrics.pnl import compute_pnl_with_tx, decompose_pnl
7+
8+
9+
def test_compute_pnl_with_tx_basic():
10+
"""Test basic P&L calculation."""
11+
S = torch.tensor([[100.0, 105.0, 110.0]]) # Single path, 3 time steps
12+
K = 105.0
13+
phi = torch.tensor([[0.5, 0.6]]) # Holdings at t0, t1
14+
lambda_tx = 0.01
15+
16+
pnl = compute_pnl_with_tx(S, K, phi, lambda_tx)
17+
18+
expected_payoff = max(110.0 - 105.0, 0) # 5.0
19+
expected_hedging_gain = 0.5 * (105.0 - 100.0) + 0.6 * (
20+
110.0 - 105.0
21+
) # 2.5 + 3.0 = 5.5
22+
expected_tx_cost = 0.01 * abs(0.6 - 0.5) * 105.0 # 0.01 * 0.1 * 105 = 0.105
23+
expected_pnl = expected_payoff - expected_hedging_gain - expected_tx_cost
24+
25+
assert torch.isclose(pnl[0], torch.tensor(expected_pnl), atol=1e-5)
26+
27+
28+
def test_compute_pnl_with_tx_no_transactions():
29+
"""Test P&L with no trading."""
30+
S = torch.tensor([[100.0, 105.0, 110.0]])
31+
K = 105.0
32+
phi = torch.tensor([[0.5, 0.5]]) # No change in holdings
33+
lambda_tx = 0.01
34+
35+
pnl = compute_pnl_with_tx(S, K, phi, lambda_tx)
36+
37+
# Should have no transaction costs
38+
expected_payoff = 5.0
39+
expected_hedging_gain = 0.5 * (105.0 - 100.0) + 0.5 * (110.0 - 105.0) # 5.0
40+
expected_pnl = expected_payoff - expected_hedging_gain # 0.0
41+
42+
assert torch.isclose(pnl[0], torch.tensor(expected_pnl), atol=1e-5)
43+
44+
45+
def test_decompose_pnl():
46+
"""Test P&L decomposition."""
47+
S = torch.tensor([[100.0, 105.0, 110.0]])
48+
K = 105.0
49+
phi = torch.tensor([[0.5, 0.6]])
50+
lambda_tx = 0.01
51+
52+
total_pnl, hedging_gain, tx_cost = decompose_pnl(S, K, phi, lambda_tx)
53+
54+
assert total_pnl.shape == (1,)
55+
assert hedging_gain.shape == (1,)
56+
assert tx_cost.shape == (1,)
57+
58+
# Verify decomposition
59+
pnl_direct = compute_pnl_with_tx(S, K, phi, lambda_tx)
60+
assert torch.allclose(total_pnl, pnl_direct)
61+
62+
63+
def test_pnl_edge_cases():
64+
"""Test edge cases."""
65+
# Out-of-the-money option
66+
S = torch.tensor([[100.0, 95.0, 90.0]])
67+
K = 105.0
68+
phi = torch.tensor([[0.0, 0.0]])
69+
lambda_tx = 0.0
70+
71+
pnl = compute_pnl_with_tx(S, K, phi, lambda_tx)
72+
assert torch.isclose(pnl[0], torch.tensor(0.0)) # Zero payoff, zero hedging gain
73+
74+
# At-the-money
75+
S = torch.tensor([[100.0, 100.0, 100.0]])
76+
K = 100.0
77+
pnl = compute_pnl_with_tx(S, K, phi, lambda_tx)
78+
assert torch.isclose(pnl[0], torch.tensor(0.0))
79+
80+
81+
@pytest.mark.parametrize(
82+
"shape",
83+
[
84+
((1, 2), (1, 1)), # Minimal case
85+
((10, 5), (10, 4)), # Multiple paths
86+
((1, 100), (1, 99)), # Long time series
87+
],
88+
)
89+
def test_pnl_shapes(shape):
90+
"""Test various input shapes."""
91+
S_shape, phi_shape = shape
92+
S = torch.randn(S_shape)
93+
K = 100.0
94+
phi = torch.rand(phi_shape)
95+
lambda_tx = 0.01
96+
97+
pnl = compute_pnl_with_tx(S, K, phi, lambda_tx)
98+
assert pnl.shape[0] == S_shape[0] # Same number of samples
99+
100+
101+
def test_pnl_gradient_flow():
102+
"""Test that gradients flow through P&L computation."""
103+
S = torch.tensor([[100.0, 105.0, 110.0]], requires_grad=True)
104+
K = torch.tensor(105.0, requires_grad=True)
105+
phi = torch.tensor([[0.5, 0.6]], requires_grad=False)
106+
lambda_tx = 0.01
107+
108+
pnl = compute_pnl_with_tx(S, K, phi, lambda_tx)
109+
pnl.sum().backward()
110+
111+
assert S.grad is not None
112+
assert K.grad is not None

tests/ml/test_train.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
"""Tests for training pipeline - functionality verification."""
2+
3+
import subprocess
4+
import sys
5+
from pathlib import Path
6+
7+
import torch
8+
9+
10+
def test_train_cli_help():
11+
"""Test that the training CLI can be invoked."""
12+
project_root = Path(__file__).parent.parent.parent
13+
result = subprocess.run(
14+
[
15+
sys.executable,
16+
"-c",
17+
"""
18+
import sys
19+
sys.path.insert(0, '.')
20+
import ml.train
21+
import argparse
22+
parser = argparse.ArgumentParser(description='Deep Hedging')
23+
parser.add_argument('--data_type_model', type=str, default='gbm')
24+
print('Parser created successfully')
25+
""",
26+
],
27+
capture_output=True,
28+
text=True,
29+
cwd=project_root,
30+
)
31+
32+
assert result.returncode == 0, f"Failed to import train module: {result.stderr}"
33+
34+
35+
def test_prepare_inputs_logic():
36+
"""Test the prepare_inputs logic directly in the test."""
37+
38+
def prepare_inputs(S, K, T, M, device="cpu"):
39+
"""Prepare input batches for training."""
40+
t_grid = torch.linspace(0, T - T / M, M, device=device)
41+
tau_grid = T - t_grid
42+
N = S.size(0)
43+
tau_batch = tau_grid.unsqueeze(0).expand(N, -1)
44+
moneyness_batch = S[:, :-1] / K
45+
return tau_batch.reshape(-1), moneyness_batch.reshape(-1)
46+
47+
S = torch.tensor([[100.0, 105.0, 110.0]]) # Shape: (1, 3) - N=1, M=2
48+
K = 105.0 # Scalar
49+
T = 1.0
50+
M = 2 # 2 time steps between 0 and T
51+
52+
tau_flat, moneyness_flat = prepare_inputs(S, K, T, M)
53+
assert tau_flat.shape == torch.Size(
54+
[2]
55+
), f"Expected shape [2], got {tau_flat.shape}"
56+
assert moneyness_flat.shape == torch.Size(
57+
[2]
58+
), f"Expected shape [2], got {moneyness_flat.shape}"
59+
60+
# Check individual values
61+
assert torch.allclose(
62+
tau_flat[0], torch.tensor(1.0), atol=1e-6
63+
), f"First tau value incorrect: {tau_flat[0]}"
64+
assert torch.allclose(
65+
tau_flat[1], torch.tensor(0.5), atol=1e-6
66+
), f"Second tau value incorrect: {tau_flat[1]}"
67+
assert torch.allclose(
68+
moneyness_flat[0], torch.tensor(100.0 / 105.0), atol=1e-6
69+
), f"First moneyness value incorrect: {moneyness_flat[0]}"
70+
assert torch.allclose(
71+
moneyness_flat[1], torch.tensor(105.0 / 105.0), atol=1e-6
72+
), f"Second moneyness value incorrect: {moneyness_flat[1]}"
73+
74+
75+
def test_train_script_syntax():
76+
"""Test that the train script has valid Python syntax."""
77+
project_root = Path(__file__).parent.parent.parent
78+
train_file = project_root / "ml" / "train.py"
79+
80+
with open(train_file, "r") as f:
81+
code = f.read()
82+
83+
# This will raise SyntaxError if there are syntax issues
84+
compile(code, str(train_file), "exec")
85+
86+
87+
def test_train_can_run_minimal():
88+
"""Test that train script can run with minimal configuration."""
89+
project_root = Path(__file__).parent.parent.parent
90+
91+
# Create a minimal test that doesn't actually train but validates imports
92+
test_code = """
93+
import sys
94+
sys.path.insert(0, '.')
95+
sys.path.insert(0, './src')
96+
97+
# Test imports work
98+
from quantlab.ml.models.hedge_net import HedgeNet
99+
from quantlab.ml.metrics.pnl import compute_pnl_with_tx
100+
101+
# Verify basic functionality
102+
net = HedgeNet(hidden_dim=16)
103+
print('Basic imports and instantiation successful')
104+
"""
105+
106+
result = subprocess.run(
107+
[sys.executable, "-c", test_code],
108+
capture_output=True,
109+
text=True,
110+
cwd=project_root,
111+
)
112+
113+
assert result.returncode == 0, f"Basic functionality test failed: {result.stderr}"

0 commit comments

Comments
 (0)