|
| 1 | +"""Feature capability resolution for AdCP. |
| 2 | +
|
| 3 | +Shared logic for resolving feature support from a capabilities response. |
| 4 | +Used by both the client (buyer-side validation) and server (seller-side validation). |
| 5 | +""" |
| 6 | + |
| 7 | +from __future__ import annotations |
| 8 | + |
| 9 | +# GetAdcpCapabilitiesResponse is under TYPE_CHECKING to avoid a circular import |
| 10 | +# (adcp.types imports from generated_poc which imports from adcp.types.base). |
| 11 | +# This is safe because `from __future__ import annotations` makes all annotations |
| 12 | +# strings that are never evaluated at runtime. |
| 13 | +from typing import TYPE_CHECKING, Any |
| 14 | + |
| 15 | +from adcp.exceptions import ADCPFeatureUnsupportedError |
| 16 | + |
| 17 | +if TYPE_CHECKING: |
| 18 | + from adcp.types.generated_poc.protocol.get_adcp_capabilities_response import ( |
| 19 | + GetAdcpCapabilitiesResponse, |
| 20 | + ) |
| 21 | + |
| 22 | +# Mapping from AdCP task names to the media_buy.features flag they require. |
| 23 | +# Only includes tasks that exist on ADCPClient and ADCPHandler. |
| 24 | +# Other features (audience_targeting, catalog_management, etc.) will be added |
| 25 | +# here when their corresponding task methods are implemented. |
| 26 | +TASK_FEATURE_MAP: dict[str, str] = { |
| 27 | + "sync_event_sources": "conversion_tracking", |
| 28 | + "log_event": "conversion_tracking", |
| 29 | +} |
| 30 | + |
| 31 | +# Derived: feature -> list of handler methods that implement it. |
| 32 | +# Used by validate_capabilities() to check that sellers implement what they declare. |
| 33 | +FEATURE_HANDLER_MAP: dict[str, list[str]] = {} |
| 34 | +for _task, _feature in TASK_FEATURE_MAP.items(): |
| 35 | + FEATURE_HANDLER_MAP.setdefault(_feature, []).append(_task) |
| 36 | + |
| 37 | + |
| 38 | +class FeatureResolver: |
| 39 | + """Resolves feature support from a GetAdcpCapabilitiesResponse. |
| 40 | +
|
| 41 | + Supports multiple feature namespaces: |
| 42 | +
|
| 43 | + - Protocol support: ``"media_buy"`` checks ``supported_protocols`` |
| 44 | + - Extension support: ``"ext:scope3"`` checks ``extensions_supported`` |
| 45 | + - Targeting: ``"targeting.geo_countries"`` checks |
| 46 | + ``media_buy.execution.targeting`` |
| 47 | + - Media buy features: ``"audience_targeting"`` checks |
| 48 | + ``media_buy.features`` |
| 49 | + - Signals features: ``"catalog_signals"`` checks |
| 50 | + ``signals.features`` |
| 51 | + """ |
| 52 | + |
| 53 | + def __init__(self, capabilities: GetAdcpCapabilitiesResponse) -> None: |
| 54 | + self._caps = capabilities |
| 55 | + |
| 56 | + # Pre-compute the set of valid protocol names so supports() doesn't |
| 57 | + # need a runtime import on every call. |
| 58 | + from adcp.types.generated_poc.protocol.get_adcp_capabilities_response import ( |
| 59 | + SupportedProtocol, |
| 60 | + ) |
| 61 | + |
| 62 | + self._valid_protocols = {p.value for p in SupportedProtocol} |
| 63 | + self._declared_protocols = {p.value for p in capabilities.supported_protocols} |
| 64 | + |
| 65 | + @property |
| 66 | + def capabilities(self) -> GetAdcpCapabilitiesResponse: |
| 67 | + return self._caps |
| 68 | + |
| 69 | + def supports(self, feature: str) -> bool: |
| 70 | + """Check if a feature is supported.""" |
| 71 | + caps = self._caps |
| 72 | + |
| 73 | + # Extension check: "ext:scope3" |
| 74 | + if feature.startswith("ext:"): |
| 75 | + ext_name = feature[4:] |
| 76 | + if caps.extensions_supported is None: |
| 77 | + return False |
| 78 | + return any(item.root == ext_name for item in caps.extensions_supported) |
| 79 | + |
| 80 | + # Targeting check: "targeting.geo_countries" |
| 81 | + if feature.startswith("targeting."): |
| 82 | + attr_name = feature[len("targeting."):] |
| 83 | + if caps.media_buy is None or caps.media_buy.execution is None: |
| 84 | + return False |
| 85 | + targeting = caps.media_buy.execution.targeting |
| 86 | + if targeting is None: |
| 87 | + return False |
| 88 | + if attr_name not in type(targeting).model_fields: |
| 89 | + return False |
| 90 | + val = getattr(targeting, attr_name, None) |
| 91 | + # For bool fields, check truthiness. For object fields (like geo_metros), |
| 92 | + # presence means supported. |
| 93 | + return val is not None and val is not False |
| 94 | + |
| 95 | + # Protocol check: if the string is a known protocol name, resolve it |
| 96 | + # against supported_protocols and stop — don't fall through to features. |
| 97 | + if feature in self._declared_protocols: |
| 98 | + return True |
| 99 | + if feature in self._valid_protocols: |
| 100 | + return False |
| 101 | + |
| 102 | + # Media buy features check |
| 103 | + if caps.media_buy is not None and caps.media_buy.features is not None: |
| 104 | + if feature in type(caps.media_buy.features).model_fields: |
| 105 | + val = getattr(caps.media_buy.features, feature, None) |
| 106 | + if val is True: |
| 107 | + return True |
| 108 | + |
| 109 | + # Signals features check |
| 110 | + if caps.signals is not None and caps.signals.features is not None: |
| 111 | + if feature in type(caps.signals.features).model_fields: |
| 112 | + val = getattr(caps.signals.features, feature, None) |
| 113 | + if val is True: |
| 114 | + return True |
| 115 | + |
| 116 | + return False |
| 117 | + |
| 118 | + def require( |
| 119 | + self, |
| 120 | + *features: str, |
| 121 | + agent_id: str | None = None, |
| 122 | + agent_uri: str | None = None, |
| 123 | + ) -> None: |
| 124 | + """Assert that all listed features are supported. |
| 125 | +
|
| 126 | + Args: |
| 127 | + *features: Feature identifiers to require. |
| 128 | + agent_id: Optional agent ID for error context. |
| 129 | + agent_uri: Optional agent URI for error context. |
| 130 | +
|
| 131 | + Raises: |
| 132 | + ADCPFeatureUnsupportedError: If any features are not supported. |
| 133 | + """ |
| 134 | + unsupported = [f for f in features if not self.supports(f)] |
| 135 | + if not unsupported: |
| 136 | + return |
| 137 | + |
| 138 | + declared = self.get_declared_features() |
| 139 | + |
| 140 | + raise ADCPFeatureUnsupportedError( |
| 141 | + unsupported_features=unsupported, |
| 142 | + declared_features=declared, |
| 143 | + agent_id=agent_id, |
| 144 | + agent_uri=agent_uri, |
| 145 | + ) |
| 146 | + |
| 147 | + def get_declared_features(self) -> list[str]: |
| 148 | + """Collect all features the response declares as supported.""" |
| 149 | + caps = self._caps |
| 150 | + declared: list[str] = [] |
| 151 | + |
| 152 | + # Supported protocols |
| 153 | + for p in caps.supported_protocols: |
| 154 | + declared.append(p.value) |
| 155 | + |
| 156 | + # Media buy features |
| 157 | + if caps.media_buy is not None and caps.media_buy.features is not None: |
| 158 | + for field_name in type(caps.media_buy.features).model_fields: |
| 159 | + if getattr(caps.media_buy.features, field_name, None) is True: |
| 160 | + declared.append(field_name) |
| 161 | + |
| 162 | + # Signals features |
| 163 | + if caps.signals is not None and caps.signals.features is not None: |
| 164 | + for field_name in type(caps.signals.features).model_fields: |
| 165 | + if getattr(caps.signals.features, field_name, None) is True: |
| 166 | + declared.append(field_name) |
| 167 | + |
| 168 | + # Targeting features |
| 169 | + if caps.media_buy is not None and caps.media_buy.execution is not None: |
| 170 | + targeting = caps.media_buy.execution.targeting |
| 171 | + if targeting is not None: |
| 172 | + for field_name in type(targeting).model_fields: |
| 173 | + val = getattr(targeting, field_name, None) |
| 174 | + if val is not None and val is not False: |
| 175 | + declared.append(f"targeting.{field_name}") |
| 176 | + |
| 177 | + # Extensions |
| 178 | + if caps.extensions_supported is not None: |
| 179 | + for item in caps.extensions_supported: |
| 180 | + declared.append(f"ext:{item.root}") |
| 181 | + |
| 182 | + return declared |
| 183 | + |
| 184 | + |
| 185 | +def validate_capabilities( |
| 186 | + handler: Any, |
| 187 | + capabilities: GetAdcpCapabilitiesResponse, |
| 188 | +) -> list[str]: |
| 189 | + """Check that a handler implements the methods required by its declared features. |
| 190 | +
|
| 191 | + Compares the features declared in a capabilities response against the handler's |
| 192 | + method implementations. Returns warnings for features that are declared but |
| 193 | + whose corresponding handler methods are not overridden from the base class. |
| 194 | +
|
| 195 | + This is a development-time check — call it at startup to catch misconfigurations. |
| 196 | +
|
| 197 | + Args: |
| 198 | + handler: An ADCPHandler instance (or any object with handler methods). |
| 199 | + capabilities: The capabilities response the handler will serve. |
| 200 | +
|
| 201 | + Returns: |
| 202 | + List of warning strings. Empty if everything is consistent. |
| 203 | + """ |
| 204 | + # Late import to avoid circular dependency: server.base imports from adcp.types |
| 205 | + # which may transitively import from this module. |
| 206 | + from adcp.server.base import ADCPHandler |
| 207 | + |
| 208 | + resolver = FeatureResolver(capabilities) |
| 209 | + warnings: list[str] = [] |
| 210 | + |
| 211 | + for feature, handler_methods in FEATURE_HANDLER_MAP.items(): |
| 212 | + if not resolver.supports(feature): |
| 213 | + continue |
| 214 | + |
| 215 | + for method_name in handler_methods: |
| 216 | + if not hasattr(handler, method_name): |
| 217 | + warnings.append( |
| 218 | + f"Feature '{feature}' is declared but handler has no " |
| 219 | + f"'{method_name}' method" |
| 220 | + ) |
| 221 | + continue |
| 222 | + |
| 223 | + # Walk MRO to check if any class between the leaf and ADCPHandler |
| 224 | + # overrides the method (handles mixin / intermediate-class patterns). |
| 225 | + if isinstance(handler, ADCPHandler): |
| 226 | + overridden = any( |
| 227 | + method_name in cls.__dict__ |
| 228 | + for cls in type(handler).__mro__ |
| 229 | + if cls is not ADCPHandler and not issubclass(ADCPHandler, cls) |
| 230 | + ) |
| 231 | + if not overridden: |
| 232 | + warnings.append( |
| 233 | + f"Feature '{feature}' is declared but '{method_name}' " |
| 234 | + f"is not overridden from ADCPHandler" |
| 235 | + ) |
| 236 | + |
| 237 | + return warnings |
0 commit comments