Skip to content

Commit 55d62af

Browse files
bokelleyclaude
andcommitted
fix: constrain getattr to model_fields, add targeting to declared features
Addresses expert review feedback: - Add model_fields guard on media_buy.features and signals.features getattr (consistent with targeting namespace, prevents Pydantic internal access) - Include targeting features in get_declared_features() for better error messages - Fix test assertion (or → and) in test_require_error_lists_declared_features - Add edge case tests: empty require(), nonexistent targeting field, model_fields guard, unmapped task, success=True with data=None - Move repeated in-method imports to module level Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent ae74634 commit 55d62af

File tree

2 files changed

+75
-34
lines changed

2 files changed

+75
-34
lines changed

src/adcp/capabilities.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,15 +101,17 @@ def supports(self, feature: str) -> bool:
101101

102102
# Media buy features check
103103
if caps.media_buy is not None and caps.media_buy.features is not None:
104-
val = getattr(caps.media_buy.features, feature, None)
105-
if val is True:
106-
return True
104+
if feature in type(caps.media_buy.features).model_fields:
105+
val = getattr(caps.media_buy.features, feature, None)
106+
if val is True:
107+
return True
107108

108109
# Signals features check
109110
if caps.signals is not None and caps.signals.features is not None:
110-
val = getattr(caps.signals.features, feature, None)
111-
if val is True:
112-
return True
111+
if feature in type(caps.signals.features).model_fields:
112+
val = getattr(caps.signals.features, feature, None)
113+
if val is True:
114+
return True
113115

114116
return False
115117

@@ -163,6 +165,15 @@ def get_declared_features(self) -> list[str]:
163165
if getattr(caps.signals.features, field_name, None) is True:
164166
declared.append(field_name)
165167

168+
# Targeting features
169+
if caps.media_buy is not None and caps.media_buy.execution is not None:
170+
targeting = caps.media_buy.execution.targeting
171+
if targeting is not None:
172+
for field_name in type(targeting).model_fields:
173+
val = getattr(targeting, field_name, None)
174+
if val is not None and val is not False:
175+
declared.append(f"targeting.{field_name}")
176+
166177
# Extensions
167178
if caps.extensions_supported is not None:
168179
for item in caps.extensions_supported:

tests/test_capabilities.py

Lines changed: 58 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77

88
import pytest
99

10-
from adcp import ADCPClient
10+
from adcp import AccountReference, ADCPClient, SyncEventSourcesRequest
11+
from adcp.capabilities import FeatureResolver, validate_capabilities
1112
from adcp.exceptions import ADCPError, ADCPFeatureUnsupportedError
13+
from adcp.server.base import ADCPHandler
1214
from adcp.types.core import AgentConfig, Protocol, TaskResult, TaskStatus
1315
from adcp.types.generated_poc.core.media_buy_features import MediaBuyFeatures
1416
from adcp.types.generated_poc.protocol.get_adcp_capabilities_response import (
@@ -91,8 +93,6 @@ def _client_with_caps(
9193
**kwargs,
9294
) -> ADCPClient:
9395
"""Create a client and inject cached capabilities."""
94-
from adcp.capabilities import FeatureResolver
95-
9696
client = ADCPClient(_make_config(), **kwargs)
9797
client._capabilities = caps
9898
client._feature_resolver = FeatureResolver(caps)
@@ -193,6 +193,27 @@ def test_targeting_no_execution(self):
193193

194194
assert client.supports("targeting.geo_countries") is False
195195

196+
def test_targeting_nonexistent_field(self):
197+
"""Targeting field not in model_fields returns False."""
198+
caps = _make_capabilities(
199+
targeting={"geo_countries": True}
200+
)
201+
client = _client_with_caps(caps)
202+
203+
assert client.supports("targeting.nonexistent_field") is False
204+
assert client.supports("targeting.__class__") is False
205+
206+
def test_model_fields_guard_on_features(self):
207+
"""Pydantic internals are not treated as features."""
208+
caps = _make_capabilities(
209+
media_buy_features={"audience_targeting": True}
210+
)
211+
client = _client_with_caps(caps)
212+
213+
assert client.supports("model_dump") is False
214+
assert client.supports("model_fields") is False
215+
assert client.supports("__class__") is False
216+
196217
def test_unknown_feature_returns_false(self):
197218
caps = _make_capabilities(
198219
media_buy_features={"audience_targeting": True}
@@ -234,6 +255,12 @@ async def test_fetch_then_supports(self):
234255
class TestRequire:
235256
"""Tests for seller.require(*features)."""
236257

258+
def test_require_no_features_is_noop(self):
259+
"""require() with zero arguments does not raise."""
260+
caps = _make_capabilities()
261+
client = _client_with_caps(caps)
262+
client.require() # Should not raise
263+
237264
def test_require_all_present(self):
238265
caps = _make_capabilities(
239266
media_buy_features={"audience_targeting": True, "conversion_tracking": True}
@@ -299,7 +326,7 @@ def test_require_error_lists_declared_features(self):
299326
client.require("audience_targeting")
300327

301328
error_str = str(exc_info.value)
302-
assert "inline_creative_management" in error_str or "conversion_tracking" in error_str
329+
assert "inline_creative_management" in error_str and "conversion_tracking" in error_str
303330

304331
def test_require_includes_agent_context(self):
305332
caps = _make_capabilities()
@@ -402,6 +429,24 @@ async def test_refresh_capabilities_failure_raises(self):
402429
with pytest.raises(ADCPError, match="Failed to fetch capabilities"):
403430
await client.refresh_capabilities()
404431

432+
@pytest.mark.asyncio
433+
async def test_fetch_capabilities_success_true_data_none(self):
434+
"""success=True but data=None still raises."""
435+
client = ADCPClient(_make_config())
436+
437+
result = TaskResult(
438+
status=TaskStatus.COMPLETED,
439+
data=None,
440+
success=True,
441+
)
442+
with patch.object(
443+
client, "get_adcp_capabilities", new_callable=AsyncMock
444+
) as mock_get:
445+
mock_get.return_value = result
446+
447+
with pytest.raises(ADCPError, match="Failed to fetch capabilities"):
448+
await client.fetch_capabilities()
449+
405450

406451
# ========================================================================
407452
# Automatic validation (validate_features) tests
@@ -411,14 +456,20 @@ async def test_refresh_capabilities_failure_raises(self):
411456
class TestValidateFeatures:
412457
"""Tests for automatic feature validation on task calls."""
413458

459+
def test_validate_skips_unmapped_task(self):
460+
"""Tasks not in TASK_FEATURE_MAP are not validated."""
461+
caps = _make_capabilities(media_buy_features={})
462+
client = _client_with_caps(caps, validate_features=True)
463+
# get_products is not in TASK_FEATURE_MAP, should not raise
464+
client._validate_task_features("get_products")
465+
414466
@pytest.mark.asyncio
415467
async def test_sync_event_sources_requires_conversion_tracking(self):
416468
caps = _make_capabilities(
417469
media_buy_features={"audience_targeting": True}
418470
)
419471
client = _client_with_caps(caps, validate_features=True)
420472

421-
from adcp import AccountReference, SyncEventSourcesRequest
422473

423474
request = SyncEventSourcesRequest(
424475
account=AccountReference(account_id="acc1"),
@@ -448,7 +499,6 @@ async def test_validation_passes_when_feature_supported(self):
448499
)
449500
client = _client_with_caps(caps, validate_features=True)
450501

451-
from adcp import AccountReference, SyncEventSourcesRequest
452502

453503
mock_result = TaskResult(
454504
status=TaskStatus.COMPLETED,
@@ -470,7 +520,6 @@ async def test_validation_skipped_when_not_opted_in(self):
470520
)
471521
client = _client_with_caps(caps, validate_features=False)
472522

473-
from adcp import AccountReference, SyncEventSourcesRequest
474523

475524
mock_result = TaskResult(
476525
status=TaskStatus.COMPLETED,
@@ -490,7 +539,6 @@ async def test_validation_skipped_when_no_capabilities(self):
490539
"""When capabilities haven't been fetched, skip validation."""
491540
client = ADCPClient(_make_config(), validate_features=True)
492541

493-
from adcp import AccountReference, SyncEventSourcesRequest
494542

495543
mock_result = TaskResult(
496544
status=TaskStatus.COMPLETED,
@@ -560,8 +608,6 @@ class TestFeatureResolver:
560608
"""Tests for FeatureResolver used independently of ADCPClient."""
561609

562610
def test_resolver_supports(self):
563-
from adcp.capabilities import FeatureResolver
564-
565611
caps = _make_capabilities(
566612
protocols=["media_buy", "signals"],
567613
media_buy_features={"audience_targeting": True},
@@ -577,8 +623,6 @@ def test_resolver_supports(self):
577623
assert resolver.supports("conversion_tracking") is False
578624

579625
def test_resolver_require_raises(self):
580-
from adcp.capabilities import FeatureResolver
581-
582626
caps = _make_capabilities(media_buy_features={})
583627
resolver = FeatureResolver(caps)
584628

@@ -588,12 +632,11 @@ def test_resolver_require_raises(self):
588632
assert exc_info.value.agent_id == "test"
589633

590634
def test_resolver_get_declared_features(self):
591-
from adcp.capabilities import FeatureResolver
592-
593635
caps = _make_capabilities(
594636
protocols=["media_buy"],
595637
media_buy_features={"audience_targeting": True, "conversion_tracking": True},
596638
extensions=["scope3"],
639+
targeting={"geo_countries": True},
597640
)
598641
resolver = FeatureResolver(caps)
599642

@@ -602,10 +645,9 @@ def test_resolver_get_declared_features(self):
602645
assert "audience_targeting" in declared
603646
assert "conversion_tracking" in declared
604647
assert "ext:scope3" in declared
648+
assert "targeting.geo_countries" in declared
605649

606650
def test_resolver_capabilities_property(self):
607-
from adcp.capabilities import FeatureResolver
608-
609651
caps = _make_capabilities()
610652
resolver = FeatureResolver(caps)
611653
assert resolver.capabilities is caps
@@ -620,8 +662,6 @@ class TestValidateCapabilities:
620662
"""Tests for server-side validate_capabilities()."""
621663

622664
def test_warns_on_declared_but_unimplemented(self):
623-
from adcp.capabilities import validate_capabilities
624-
from adcp.server.base import ADCPHandler
625665

626666
class MyHandler(ADCPHandler):
627667
pass # Doesn't override anything
@@ -636,8 +676,6 @@ class MyHandler(ADCPHandler):
636676
assert any("log_event" in w or "sync_event_sources" in w for w in warnings)
637677

638678
def test_no_warnings_when_handler_overrides(self):
639-
from adcp.capabilities import validate_capabilities
640-
from adcp.server.base import ADCPHandler
641679

642680
class MyHandler(ADCPHandler):
643681
async def log_event(self, params, context=None):
@@ -654,8 +692,6 @@ async def sync_event_sources(self, params, context=None):
654692
assert len(warnings) == 0
655693

656694
def test_no_warnings_when_feature_not_declared(self):
657-
from adcp.capabilities import validate_capabilities
658-
from adcp.server.base import ADCPHandler
659695

660696
class MyHandler(ADCPHandler):
661697
pass
@@ -669,8 +705,6 @@ class MyHandler(ADCPHandler):
669705

670706
def test_warns_on_partial_implementation(self):
671707
"""If only some handler methods for a feature are overridden."""
672-
from adcp.capabilities import validate_capabilities
673-
from adcp.server.base import ADCPHandler
674708

675709
class MyHandler(ADCPHandler):
676710
async def log_event(self, params, context=None):
@@ -688,8 +722,6 @@ async def log_event(self, params, context=None):
688722

689723
def test_no_warnings_when_mixin_overrides(self):
690724
"""Overrides inherited from an intermediate class are detected."""
691-
from adcp.capabilities import validate_capabilities
692-
from adcp.server.base import ADCPHandler
693725

694726
class ConversionMixin(ADCPHandler):
695727
async def log_event(self, params, context=None):
@@ -710,8 +742,6 @@ class MyHandler(ConversionMixin):
710742

711743
def test_multiple_handler_methods_warned(self):
712744
"""All handler methods for a declared feature produce warnings when unoverridden."""
713-
from adcp.capabilities import validate_capabilities
714-
from adcp.server.base import ADCPHandler
715745

716746
class MyHandler(ADCPHandler):
717747
pass

0 commit comments

Comments
 (0)