diff --git a/cmd/cli/commands/run.go b/cmd/cli/commands/run.go index 58e92913..f3d7beb4 100644 --- a/cmd/cli/commands/run.go +++ b/cmd/cli/commands/run.go @@ -809,11 +809,34 @@ func newRunCmd() *cobra.Command { return nil } - _, err := desktopClient.Inspect(model, false) - if err != nil { - if !errors.Is(err, desktop.ErrNotFound) { - return handleClientError(err, "Failed to inspect model") + modelInfo, err := desktopClient.Inspect(model, false) + modelFoundLocally := err == nil + if err != nil && !errors.Is(err, desktop.ErrNotFound) { + return handleClientError(err, "Failed to inspect model") + } + + if !modelFoundLocally { + remoteInfo, remoteErr := desktopClient.Inspect(model, true) + if remoteErr == nil { + modelInfo = remoteInfo + } + } + + backend := "" + if modelInfo.ID != "" { + backend, _ = GetRequiredBackendFromModelInfo(&modelInfo) + } + + if backend != "" { + if err := EnsureBackendAvailable(backend, cmd); err != nil { + if err.Error() == "backend installation cancelled" { + return nil + } + return err } + } + + if !modelFoundLocally { cmd.Println("Unable to find model '" + model + "' locally. Pulling from the server.") if err := pullModel(cmd, desktopClient, model); err != nil { return err diff --git a/cmd/cli/commands/utils.go b/cmd/cli/commands/utils.go index c77f472b..27d50482 100644 --- a/cmd/cli/commands/utils.go +++ b/cmd/cli/commands/utils.go @@ -1,7 +1,9 @@ package commands import ( + "bufio" "bytes" + "encoding/json" "errors" "fmt" "io" @@ -12,7 +14,10 @@ import ( "github.com/docker/model-runner/cmd/cli/desktop" "github.com/docker/model-runner/cmd/cli/pkg/standalone" "github.com/docker/model-runner/pkg/distribution/oci/reference" + "github.com/docker/model-runner/pkg/distribution/types" + "github.com/docker/model-runner/pkg/inference/backends/llamacpp" "github.com/docker/model-runner/pkg/inference/backends/vllm" + dmrm "github.com/docker/model-runner/pkg/inference/models" "github.com/moby/term" "github.com/olekukonko/tablewriter" "github.com/olekukonko/tablewriter/renderer" @@ -270,3 +275,114 @@ func newTable(w io.Writer) *tablewriter.Table { }), ) } + +func CheckBackendInstalled(backend string) (bool, error) { + status := desktopClient.Status() + if status.Error != nil { + return false, fmt.Errorf("failed to get backend status: %w", status.Error) + } + + var backendStatus map[string]string + if err := json.Unmarshal(status.Status, &backendStatus); err != nil { + return false, fmt.Errorf("failed to parse backend status: %w", err) + } + + backendState, exists := backendStatus[backend] + if !exists { + return false, nil + } + + state := strings.TrimSpace(strings.ToLower(backendState)) + if strings.HasPrefix(state, "not ") || strings.HasPrefix(state, "error") { + return false, nil + } + + return strings.HasPrefix(state, "installed") || strings.HasPrefix(state, "running"), nil +} + +func PromptInstallBackend(backend string, cmd *cobra.Command) (bool, error) { + fmt.Fprintf(cmd.OutOrStdout(), "Backend %q is not installed. Download and install it now? [Y/n]: ", backend) + + reader := bufio.NewReader(os.Stdin) + input, err := reader.ReadString('\n') + if err != nil { + return false, fmt.Errorf("failed to read input: %w", err) + } + + input = strings.TrimSpace(strings.ToLower(input)) + return input == "" || input == "y" || input == "yes", nil +} + +func InstallBackend(backend string, cmd *cobra.Command) error { + installCmd := newInstallRunner() + installCmd.SetArgs([]string{"--backend", backend}) + + if err := installCmd.Execute(); err != nil { + return fmt.Errorf("failed to install backend %s: %w", backend, err) + } + + return nil +} + +func EnsureBackendAvailable(backend string, cmd *cobra.Command) error { + installed, err := CheckBackendInstalled(backend) + if err != nil { + return err + } + + if installed { + return nil + } + + confirm, err := PromptInstallBackend(backend, cmd) + if err != nil { + return err + } + + if !confirm { + cmd.Printf("Run 'docker model install-runner --backend %s' to install it manually.\n", backend) + return fmt.Errorf("backend installation cancelled") + } + + if err := InstallBackend(backend, cmd); err != nil { + return err + } + + installed, err = CheckBackendInstalled(backend) + if err != nil { + return err + } + if !installed { + return fmt.Errorf("backend %q is still not installed; run 'docker model install-runner --backend %s'", backend, backend) + } + + cmd.Printf("Backend %q installed successfully.\n", backend) + return nil +} + +func GetRequiredBackend(model string) (string, error) { + modelInfo, err := desktopClient.Inspect(model, false) + if err != nil { + return "", err + } + + return GetRequiredBackendFromModelInfo(&modelInfo) +} + +func GetRequiredBackendFromModelInfo(modelInfo *dmrm.Model) (string, error) { + config, ok := modelInfo.Config.(*types.Config) + if !ok { + return llamacpp.Name, nil + } + + switch config.Format { + case types.FormatSafetensors: + return vllm.Name, nil + case types.FormatGGUF: + return llamacpp.Name, nil + case types.FormatDiffusers: + return "diffusers", nil + default: + return llamacpp.Name, nil + } +}