Skip to content
29 changes: 28 additions & 1 deletion backend/go/acestep-cpp/goacestepcpp.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"fmt"
"os"
"path/filepath"
"strings"

"github.com/mudler/LocalAI/pkg/grpc/base"
Expand All @@ -11,7 +12,7 @@ import (

var (
CppLoadModel func(lmModelPath, textEncoderPath, ditModelPath, vaeModelPath string) int
CppGenerateMusic func(caption, lyrics string, bpm int, keyscale, timesignature string, duration, temperature float32, instrumental bool, seed int, dst string, threads int) int
CppGenerateMusic func(caption, lyrics string, bpm int, keyscale, timesignature string, duration, temperature float32, instrumental bool, seed int, dst string, threads int)
Comment thread
mudler marked this conversation as resolved.
Outdated
)

type AceStepCpp struct {
Expand All @@ -22,6 +23,9 @@ func (a *AceStepCpp) Load(opts *pb.ModelOptions) error {
// ModelFile is the LM model path
lmModel := opts.ModelFile

// Get the base directory from ModelFile for resolving relative paths
baseDir := filepath.Dir(lmModel)

var textEncoderModel, ditModel, vaeModel string

for _, oo := range opts.Options {
Expand Down Expand Up @@ -52,6 +56,29 @@ func (a *AceStepCpp) Load(opts *pb.ModelOptions) error {
return fmt.Errorf("vae_model option is required")
}

// Resolve relative paths to the base directory
// If the path doesn't start with "/" it's relative
if !filepath.IsAbs(textEncoderModel) {
textEncoderModel = filepath.Join(baseDir, textEncoderModel)
}
if !filepath.IsAbs(ditModel) {
ditModel = filepath.Join(baseDir, ditModel)
}
if !filepath.IsAbs(vaeModel) {
vaeModel = filepath.Join(baseDir, vaeModel)
}

// Also resolve the lmModel if it's relative
if !filepath.IsAbs(lmModel) {
lmModel = filepath.Join(baseDir, lmModel)
}

fmt.Fprintf(os.Stderr, "[acestep-cpp] Resolved paths:\n")
fmt.Fprintf(os.Stderr, " LM Model: %s\n", lmModel)
fmt.Fprintf(os.Stderr, " Text Encoder: %s\n", textEncoderModel)
fmt.Fprintf(os.Stderr, " DiT Model: %s\n", ditModel)
fmt.Fprintf(os.Stderr, " VAE Model: %s\n", vaeModel)

if ret := CppLoadModel(lmModel, textEncoderModel, ditModel, vaeModel); ret != 0 {
return fmt.Errorf("failed to load acestep models (error code: %d)", ret)
}
Expand Down
Loading