Skip to content
Draft
29 changes: 23 additions & 6 deletions dnastack/cli/commands/explorer/questions/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,13 @@ def describe_question(question_id: str, output: str, context: Optional[str], end
arg_names=['--output-file'],
help='Output file path for results'
),
ArgumentSpec(
name='local_federated',
arg_names=['--local-federated'],
help='Query collections directly via local federation instead of using server-side federation',
type=bool,
default=False
),
DATA_OUTPUT_ARG,
CONTEXT_ARG,
SINGLE_ENDPOINT_ID_ARG,
Expand All @@ -112,13 +119,15 @@ def ask_question(
args: tuple,
collections: Optional[JsonLike],
output_file: Optional[str],
local_federated: bool,
output: str,
context: Optional[str],
endpoint_id: Optional[str]
):
"""Ask a federated question with the provided parameters"""
trace = Span()
client = get_explorer_client(context=context, endpoint_id=endpoint_id, trace=trace)


# Parse collections if provided
if collections:
Expand Down Expand Up @@ -162,12 +171,20 @@ def ask_question(
collection_ids = [col.id for col in question.collections]

# Execute the question
results_iter = client.ask_federated_question(
question_id=question_name,
inputs=inputs,
collections=collection_ids,
trace=trace
)
if local_federated:
results_iter = client.ask_question_local_federated(
federated_question_id=question_name,
inputs=inputs,
collections=collection_ids,
trace=trace
)
else:
results_iter = client.ask_federated_question(
question_id=question_name,
inputs=inputs,
collections=collection_ids,
trace=trace
)

# Collect results
results = list(results_iter)
Expand Down
240 changes: 238 additions & 2 deletions dnastack/client/explorer/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import List, Optional, Dict, Any, TYPE_CHECKING
from concurrent.futures import ThreadPoolExecutor, as_completed
import time

if TYPE_CHECKING:
from dnastack.client.explorer.models import FederatedQuestion
Expand All @@ -10,7 +12,8 @@
from dnastack.client.explorer.models import (
FederatedQuestion,
FederatedQuestionListResponse,
FederatedQuestionQueryRequest
FederatedQuestionQueryRequest,
QuestionCollection
)
from dnastack.client.result_iterator import ResultLoader, InactiveLoaderError, ResultIterator
from dnastack.client.service_registry.models import ServiceType
Expand Down Expand Up @@ -136,6 +139,55 @@ def ask_federated_question(
)
)

def ask_question_local_federated(
self,
federated_question_id: str,
inputs: Dict[str, str],
collections: Optional[List[str]] = None,
trace: Optional[Span] = None
) -> 'ResultIterator[Dict[str, Any]]':
"""
Query collections directly via local federation instead of server-side federation.

Args:
federated_question_id: The ID of the federated question to ask
inputs: Dictionary of parameter name -> value mappings
collections: Optional list of collection IDs to query. If None, all collections are used.
trace: Optional tracing span

Returns:
ResultIterator[Dict[str, Any]]: Iterator over aggregated query results in federated format
"""
# Get federated question metadata to obtain per-collection question IDs
question = self.describe_federated_question(federated_question_id, trace=trace)

# Filter collections if specified
if collections is not None:
# Create a map of collection ID to QuestionCollection for filtering
collection_map = {col.id: col for col in question.collections}
target_collections = [collection_map[cid] for cid in collections if cid in collection_map]

# Check for invalid collection IDs
invalid_ids = [cid for cid in collections if cid not in collection_map]
if invalid_ids:
raise ClientError(
response=None,
trace_context=trace,
message=f"Invalid collection IDs for question '{federated_question_id}': {', '.join(invalid_ids)}"
)
else:
target_collections = question.collections

# Create the result loader for local federation
return ResultIterator(
LocalFederatedQuestionQueryResultLoader(
explorer_client=self,
collections=target_collections,
inputs=inputs,
trace=trace
)
)


class FederatedQuestionListResultLoader(ResultLoader):
"""
Expand Down Expand Up @@ -248,4 +300,188 @@ def load(self) -> List[Dict[str, Any]]:
raise ClientError(e.response, e.trace, "Invalid question parameters")
else:

raise ClientError(e.response, e.trace, "Failed to execute federated question")
raise ClientError(e.response, e.trace, "Failed to execute federated question")


class LocalFederatedQuestionQueryResultLoader(ResultLoader):
"""
Result loader for local federation queries that queries each collection directly.
"""

def __init__(
self,
explorer_client: 'ExplorerClient',
collections: List[QuestionCollection],
inputs: Dict[str, str],
trace: Optional[Span] = None
):
self.__explorer_client = explorer_client
self.__collections = collections
self.__inputs = inputs
self.__trace = trace
self.__loaded = False

def has_more(self) -> bool:
return not self.__loaded

def load(self) -> List[Dict[str, Any]]:
if self.__loaded:
raise InactiveLoaderError("LocalFederatedQuestionQueryResultLoader")

# Execute parallel queries to each collection
with ThreadPoolExecutor() as executor:
# Submit all queries
future_to_collection = {
executor.submit(
self._query_single_collection,
collection
): collection
for collection in self.__collections
}

# Collect results
results = []
for future in as_completed(future_to_collection):
result = future.result()
results.append(result)

# Return results directly as a list to match federated format
self.__loaded = True
return results # Return as list to match federated endpoint format

def _query_single_collection(self, collection: QuestionCollection) -> Dict[str, Any]:
"""
Query a single collection and return the result in federated format.
Handles Data Connect pagination by following next_page_url links.
"""
start_time = time.time()

# Build the collection-specific endpoint URL
# Note: explorer URL already ends with /api/, so we don't need to add it again
initial_url = urljoin(
self.__explorer_client.url,
f"collections/{collection.slug}/questions/{collection.question_id}/query"
)

try:
# Collect all data across all pages
all_data = []
data_model = None
current_url = None
visited_urls = []

with self.__explorer_client._session as session:
# First request - POST with params to initiate query
response = session.post(
initial_url,
json={"params": self.__inputs},
trace_context=self.__trace
)
visited_urls.append(initial_url)

while True:
# Parse the Data Connect response
table_data = response.json()

# Capture data model from first response
if data_model is None and 'data_model' in table_data:
data_model = table_data['data_model']

# Add data from this page
if 'data' in table_data and isinstance(table_data['data'], list):
# Add collection_name to each item
for item in table_data['data']:
item['collection_name'] = collection.name
all_data.extend(table_data['data'])

# Check for next page
pagination = table_data.get('pagination')
if pagination and pagination.get('next_page_url'):
current_url = pagination['next_page_url']
# Handle relative URLs
if current_url and not current_url.startswith(('http://', 'https://')):
current_url = urljoin(visited_urls[-1], current_url)

# Prevent infinite loops
if current_url in visited_urls:
break

# Follow pagination with GET request
response = session.get(
current_url,
trace_context=self.__trace
)
visited_urls.append(current_url)
else:
# No more pages
break

# Build final aggregated response
aggregated_table_data = {
"data": all_data,
"data_model": data_model,
"pagination": None # No pagination in aggregated result
}

# Return in federated format
return {
"collectionId": collection.id,
"collectionSlug": collection.slug,
"results": aggregated_table_data,
"error": None,
"failureInfo": None
}

except HttpError as e:
# Calculate response time
response_time_ms = int((time.time() - start_time) * 1000)

# Determine failure reason
status_code = e.response.status_code if e.response else None
if status_code == 401:
reason = "UNAUTHORIZED"
message = f"Authentication required for collection {collection.name}"
elif status_code == 403:
reason = "FORBIDDEN"
message = f"Access denied to collection {collection.name}"
elif status_code == 404:
reason = "NOT_FOUND"
message = f"Question not found in collection {collection.name}"
elif status_code == 400:
reason = "BAD_REQUEST"
message = f"Invalid parameters for collection {collection.name}"
elif status_code and status_code >= 500:
reason = "SERVER_ERROR"
message = f"Server error for collection {collection.name}"
else:
reason = "UNKNOWN"
message = str(e)

# Return error in federated format
return {
"collectionId": collection.id,
"collectionSlug": collection.slug,
"results": None,
"error": message,
"failureInfo": {
"reason": reason,
"message": message,
"responseTimeMs": response_time_ms
}
}

except Exception as e:
# Handle non-HTTP errors
response_time_ms = int((time.time() - start_time) * 1000)

return {
"collectionId": collection.id,
"collectionSlug": collection.slug,
"results": None,
"error": str(e),
"failureInfo": {
"reason": "CLIENT_ERROR",
"message": str(e),
"responseTimeMs": response_time_ms
}
}
Loading
Loading