Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 16 additions & 30 deletions core/src/main/java/com/google/adk/tools/mcp/McpAsyncToolset.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import com.google.adk.tools.BaseTool;
import com.google.adk.tools.BaseToolset;
import com.google.adk.tools.NamedToolPredicate;
import com.google.adk.tools.ToolPredicate;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import io.modelcontextprotocol.client.McpAsyncClient;
Expand All @@ -32,8 +34,8 @@
import java.time.Duration;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;
import javax.annotation.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono;
Expand All @@ -59,14 +61,14 @@ public class McpAsyncToolset implements BaseToolset {

private final McpSessionManager mcpSessionManager;
private final ObjectMapper objectMapper;
private final Optional<Object> toolFilter;
private final @Nullable Object toolFilter;
private final AtomicReference<Mono<List<McpAsyncTool>>> mcpTools = new AtomicReference<>();

/** Builder for McpAsyncToolset */
public static class Builder {
private Object connectionParams = null;
private ObjectMapper objectMapper = null;
private Optional<Object> toolFilter = null;
private @Nullable Object toolFilter = null;

@CanIgnoreReturnValue
public Builder connectionParams(ServerParameters connectionParams) {
Expand All @@ -87,14 +89,14 @@ public Builder objectMapper(ObjectMapper objectMapper) {
}

@CanIgnoreReturnValue
public Builder toolFilter(Optional<Object> toolFilter) {
this.toolFilter = toolFilter;
public Builder toolFilter(List<String> toolNames) {
this.toolFilter = new NamedToolPredicate(Preconditions.checkNotNull(toolNames));
return this;
}

@CanIgnoreReturnValue
public Builder toolFilter(List<String> toolNames) {
this.toolFilter = Optional.of(new NamedToolPredicate(toolNames));
public Builder toolFilter(@Nullable ToolPredicate toolPredicate) {
this.toolFilter = toolPredicate;
return this;
}

Expand All @@ -118,12 +120,12 @@ public McpAsyncToolset build() {
*
* @param connectionParams The SSE connection parameters to the MCP server.
* @param objectMapper An ObjectMapper instance for parsing schemas.
* @param toolFilter An Optional containing either a ToolPredicate or a List of tool names.
* @param toolFilter Either a ToolPredicate or a List of tool names.
*/
public McpAsyncToolset(
McpAsyncToolset(
SseServerParameters connectionParams,
ObjectMapper objectMapper,
Optional<Object> toolFilter) {
@Nullable Object toolFilter) {
Objects.requireNonNull(connectionParams);
Objects.requireNonNull(objectMapper);
this.objectMapper = objectMapper;
Expand All @@ -136,41 +138,25 @@ public McpAsyncToolset(
*
* @param connectionParams The local server connection parameters to the MCP server.
* @param objectMapper An ObjectMapper instance for parsing schemas.
* @param toolFilter An Optional containing either a ToolPredicate or a List of tool names.
* @param toolFilter Either a ToolPredicate or a List of tool names or null.
*/
public McpAsyncToolset(
ServerParameters connectionParams, ObjectMapper objectMapper, Optional<Object> toolFilter) {
McpAsyncToolset(
ServerParameters connectionParams, ObjectMapper objectMapper, @Nullable Object toolFilter) {
Objects.requireNonNull(connectionParams);
Objects.requireNonNull(objectMapper);
this.objectMapper = objectMapper;
this.mcpSessionManager = new McpSessionManager(connectionParams);
this.toolFilter = toolFilter;
}

/**
* Initializes the McpAsyncToolset with a provided McpSessionManager.
*
* @param mcpSessionManager The session manager for MCP connections.
* @param objectMapper An ObjectMapper instance for parsing schemas.
* @param toolFilter An Optional containing either a ToolPredicate or a List of tool names.
*/
public McpAsyncToolset(
McpSessionManager mcpSessionManager, ObjectMapper objectMapper, Optional<Object> toolFilter) {
Objects.requireNonNull(mcpSessionManager);
Objects.requireNonNull(objectMapper);
this.objectMapper = objectMapper;
this.mcpSessionManager = mcpSessionManager;
this.toolFilter = toolFilter;
}

@Override
public Flowable<BaseTool> getTools(ReadonlyContext readonlyContext) {
return Maybe.defer(() -> Maybe.fromCompletionStage(this.initAndGetTools().toFuture()))
.defaultIfEmpty(ImmutableList.of())
.map(
tools ->
tools.stream()
.filter(tool -> isToolSelected(tool, toolFilter.orElse(null), readonlyContext))
.filter(tool -> isToolSelected(tool, toolFilter, readonlyContext))
.toList())
.onErrorResumeNext(
err -> {
Expand Down