|
| 1 | +package search |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "fmt" |
| 6 | + "io" |
| 7 | + "sort" |
| 8 | + "sync" |
| 9 | +) |
| 10 | + |
| 11 | +// SourceType represents the source to search |
| 12 | +type SourceType string |
| 13 | + |
| 14 | +const ( |
| 15 | + SourceAll SourceType = "all" |
| 16 | + SourceDockerHub SourceType = "dockerhub" |
| 17 | + SourceHuggingFace SourceType = "huggingface" |
| 18 | +) |
| 19 | + |
| 20 | +// AggregatedClient searches multiple sources and merges results |
| 21 | +type AggregatedClient struct { |
| 22 | + clients []SearchClient |
| 23 | + errOut io.Writer |
| 24 | +} |
| 25 | + |
| 26 | +// NewAggregatedClient creates a client that searches the specified sources |
| 27 | +func NewAggregatedClient(source SourceType, errOut io.Writer) *AggregatedClient { |
| 28 | + var clients []SearchClient |
| 29 | + |
| 30 | + switch source { |
| 31 | + case SourceDockerHub: |
| 32 | + clients = []SearchClient{NewDockerHubClient()} |
| 33 | + case SourceHuggingFace: |
| 34 | + clients = []SearchClient{NewHuggingFaceClient()} |
| 35 | + case SourceAll: |
| 36 | + clients = []SearchClient{ |
| 37 | + NewDockerHubClient(), |
| 38 | + NewHuggingFaceClient(), |
| 39 | + } |
| 40 | + default: // This handles any unexpected values |
| 41 | + clients = []SearchClient{ |
| 42 | + NewDockerHubClient(), |
| 43 | + NewHuggingFaceClient(), |
| 44 | + } |
| 45 | + } |
| 46 | + |
| 47 | + return &AggregatedClient{ |
| 48 | + clients: clients, |
| 49 | + errOut: errOut, |
| 50 | + } |
| 51 | +} |
| 52 | + |
| 53 | +// searchResult holds results from a single source along with any error |
| 54 | +type searchResult struct { |
| 55 | + results []SearchResult |
| 56 | + err error |
| 57 | + source string |
| 58 | +} |
| 59 | + |
| 60 | +// Search searches all configured sources and merges results |
| 61 | +func (c *AggregatedClient) Search(ctx context.Context, opts SearchOptions) ([]SearchResult, error) { |
| 62 | + // Search all sources concurrently |
| 63 | + resultsChan := make(chan searchResult, len(c.clients)) |
| 64 | + var wg sync.WaitGroup |
| 65 | + |
| 66 | + for _, client := range c.clients { |
| 67 | + wg.Add(1) |
| 68 | + go func(client SearchClient) { |
| 69 | + defer wg.Done() |
| 70 | + results, err := client.Search(ctx, opts) |
| 71 | + resultsChan <- searchResult{ |
| 72 | + results: results, |
| 73 | + err: err, |
| 74 | + source: client.Name(), |
| 75 | + } |
| 76 | + }(client) |
| 77 | + } |
| 78 | + |
| 79 | + // Wait for all searches to complete |
| 80 | + go func() { |
| 81 | + wg.Wait() |
| 82 | + close(resultsChan) |
| 83 | + }() |
| 84 | + |
| 85 | + // Collect results |
| 86 | + var allResults []SearchResult |
| 87 | + var errors []error |
| 88 | + |
| 89 | + for result := range resultsChan { |
| 90 | + if result.err != nil { |
| 91 | + errors = append(errors, fmt.Errorf("%s: %w", result.source, result.err)) |
| 92 | + if c.errOut != nil { |
| 93 | + fmt.Fprintf(c.errOut, "Warning: failed to search %s: %v\n", result.source, result.err) |
| 94 | + } |
| 95 | + continue |
| 96 | + } |
| 97 | + allResults = append(allResults, result.results...) |
| 98 | + } |
| 99 | + |
| 100 | + // If all sources failed, return the collected errors |
| 101 | + if len(allResults) == 0 && len(errors) > 0 { |
| 102 | + return nil, fmt.Errorf("all search sources failed: %v", errors) |
| 103 | + } |
| 104 | + |
| 105 | + // Sort by source (Docker Hub first), then by downloads within each source |
| 106 | + sort.Slice(allResults, func(i, j int) bool { |
| 107 | + // Docker Hub comes before HuggingFace |
| 108 | + if allResults[i].Source != allResults[j].Source { |
| 109 | + return allResults[i].Source == DockerHubSourceName |
| 110 | + } |
| 111 | + // Within same source, sort by downloads (popularity) |
| 112 | + return allResults[i].Downloads > allResults[j].Downloads |
| 113 | + }) |
| 114 | + |
| 115 | + // Limit total results if needed |
| 116 | + if opts.Limit > 0 && len(allResults) > opts.Limit { |
| 117 | + allResults = allResults[:opts.Limit] |
| 118 | + } |
| 119 | + |
| 120 | + return allResults, nil |
| 121 | +} |
| 122 | + |
| 123 | +// ParseSource parses a source string into a SourceType |
| 124 | +func ParseSource(s string) (SourceType, error) { |
| 125 | + switch s { |
| 126 | + case "all", "": |
| 127 | + return SourceAll, nil |
| 128 | + case "dockerhub", "docker", "hub": |
| 129 | + return SourceDockerHub, nil |
| 130 | + case "huggingface", "hf": |
| 131 | + return SourceHuggingFace, nil |
| 132 | + default: |
| 133 | + return "", fmt.Errorf("unknown source %q: valid options are 'all', 'dockerhub', 'docker', 'hub', 'huggingface', 'hf'", s) |
| 134 | + } |
| 135 | +} |
0 commit comments