Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions backend/go/acestep-cpp/acestepcpp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,17 @@ func TestLoadModel(t *testing.T) {
defer conn.Close()

client := pb.NewBackendClient(conn)

// Get base directory from main model file for relative paths
mainModelPath := filepath.Join(modelDir, "acestep-5Hz-lm-0.6B-Q8_0.gguf")
baseDir := filepath.Dir(mainModelPath)

resp, err := client.LoadModel(context.Background(), &pb.ModelOptions{
ModelFile: filepath.Join(modelDir, "acestep-5Hz-lm-0.6B-Q8_0.gguf"),
ModelFile: mainModelPath,
Options: []string{
"text_encoder_model:" + filepath.Join(modelDir, "Qwen3-Embedding-0.6B-Q8_0.gguf"),
"dit_model:" + filepath.Join(modelDir, "acestep-v15-turbo-Q8_0.gguf"),
"vae_model:" + filepath.Join(modelDir, "vae-BF16.gguf"),
baseDir + "/text_encoder_model:Qwen3-Embedding-0.6B-Q8_0.gguf",
baseDir + "/dit_model:acestep-v15-turbo-Q8_0.gguf",
baseDir + "/vae_model:vae-BF16.gguf",
},
})
if err != nil {
Expand Down Expand Up @@ -141,13 +146,17 @@ func TestSoundGeneration(t *testing.T) {

client := pb.NewBackendClient(conn)

// Get base directory from main model file for relative paths
mainModelPath := filepath.Join(modelDir, "acestep-5Hz-lm-0.6B-Q8_0.gguf")
baseDir := filepath.Dir(mainModelPath)

// Load models
loadResp, err := client.LoadModel(context.Background(), &pb.ModelOptions{
ModelFile: filepath.Join(modelDir, "acestep-5Hz-lm-0.6B-Q8_0.gguf"),
ModelFile: mainModelPath,
Options: []string{
"text_encoder_model:" + filepath.Join(modelDir, "Qwen3-Embedding-0.6B-Q8_0.gguf"),
"dit_model:" + filepath.Join(modelDir, "acestep-v15-turbo-Q8_0.gguf"),
"vae_model:" + filepath.Join(modelDir, "vae-BF16.gguf"),
baseDir + "/text_encoder_model:Qwen3-Embedding-0.6B-Q8_0.gguf",
baseDir + "/dit_model:acestep-v15-turbo-Q8_0.gguf",
baseDir + "/vae_model:vae-BF16.gguf",
},
})
if err != nil {
Expand All @@ -160,7 +169,7 @@ func TestSoundGeneration(t *testing.T) {
// Generate music
duration := float32(10.0)
temperature := float32(0.85)
bpm := int32(120)
bpm := float32(120.0)
caption := "A cheerful electronic dance track"
timesig := "4/4"

Expand Down
31 changes: 29 additions & 2 deletions backend/go/acestep-cpp/goacestepcpp.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package main
import (
"fmt"
"os"
"path/filepath"
"strings"

"github.com/mudler/LocalAI/pkg/grpc/base"
Expand All @@ -11,7 +12,7 @@ import (

var (
CppLoadModel func(lmModelPath, textEncoderPath, ditModelPath, vaeModelPath string) int
CppGenerateMusic func(caption, lyrics string, bpm int, keyscale, timesignature string, duration, temperature float32, instrumental bool, seed int, dst string, threads int) int
CppGenerateMusic func(caption, lyrics string, bpm float32, keyscale, timesignature string, duration, temperature float32, instrumental bool, seed int, dst string, threads int) int
)

type AceStepCpp struct {
Expand All @@ -22,6 +23,9 @@ func (a *AceStepCpp) Load(opts *pb.ModelOptions) error {
// ModelFile is the LM model path
lmModel := opts.ModelFile

// Get the base directory from ModelFile for resolving relative paths
baseDir := filepath.Dir(lmModel)

var textEncoderModel, ditModel, vaeModel string

for _, oo := range opts.Options {
Expand Down Expand Up @@ -52,6 +56,29 @@ func (a *AceStepCpp) Load(opts *pb.ModelOptions) error {
return fmt.Errorf("vae_model option is required")
}

// Resolve relative paths to the base directory
// If the path doesn't start with "/" it's relative
if !filepath.IsAbs(textEncoderModel) {
textEncoderModel = filepath.Join(baseDir, textEncoderModel)
}
if !filepath.IsAbs(ditModel) {
ditModel = filepath.Join(baseDir, ditModel)
}
if !filepath.IsAbs(vaeModel) {
vaeModel = filepath.Join(baseDir, vaeModel)
}

// Also resolve the lmModel if it's relative
if !filepath.IsAbs(lmModel) {
lmModel = filepath.Join(baseDir, lmModel)
}

fmt.Fprintf(os.Stderr, "[acestep-cpp] Resolved paths:\n")
fmt.Fprintf(os.Stderr, " LM Model: %s\n", lmModel)
fmt.Fprintf(os.Stderr, " Text Encoder: %s\n", textEncoderModel)
fmt.Fprintf(os.Stderr, " DiT Model: %s\n", ditModel)
fmt.Fprintf(os.Stderr, " VAE Model: %s\n", vaeModel)

if ret := CppLoadModel(lmModel, textEncoderModel, ditModel, vaeModel); ret != 0 {
return fmt.Errorf("failed to load acestep models (error code: %d)", ret)
}
Expand All @@ -74,7 +101,7 @@ func (a *AceStepCpp) SoundGeneration(req *pb.SoundGenerationRequest) error {
seed := 42
threads := 4

if ret := CppGenerateMusic(caption, lyrics, bpm, keyscale, timesignature, duration, temperature, instrumental, seed, req.GetDst(), threads); ret != 0 {
if ret := CppGenerateMusic(caption, lyrics, float32(bpm), keyscale, timesignature, duration, temperature, instrumental, seed, req.GetDst(), threads); ret != 0 {
return fmt.Errorf("failed to generate music (error code: %d)", ret)
}

Expand Down
Loading