Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 27 additions & 23 deletions pkg/backend/attach.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ func (b *backend) Attach(ctx context.Context, filepath string, cfg *config.Attac

logrus.Infof("attach: loaded source model config [config: %+v]", srcModelConfig)

proc := b.getProcessor(cfg.DestinationDir, filepath, cfg.Raw)
if proc == nil {
return fmt.Errorf("failed to get processor for file %s", filepath)
proc, err := b.getProcessor(cfg.DestinationDir, filepath, cfg.Raw)
if err != nil {
return fmt.Errorf("failed to get processor: %w", err)
}

builder, err := b.getBuilder(cfg.Target, cfg)
Expand Down Expand Up @@ -305,40 +305,44 @@ func (b *backend) getModelConfig(ctx context.Context, reference string, desc oci
return &model, nil
}

func (b *backend) getProcessor(destDir, filepath string, rawMediaType bool) processor.Processor {
if modelfile.IsFileType(filepath, modelfile.ConfigFilePatterns) {
mediaType := modelspec.MediaTypeModelWeightConfig
func (b *backend) getProcessor(destDir, filepath string, rawMediaType bool) (processor.Processor, error) {
info, err := os.Stat(filepath)
if err != nil {
return nil, fmt.Errorf("failed to stat file %s: %w", filepath, err)
}

fileType := modelfile.InferFileType(filepath, info.Size())

var mediaType string
switch fileType {
case modelfile.FileTypeConfig:
mediaType = modelspec.MediaTypeModelWeightConfig
if rawMediaType {
mediaType = modelspec.MediaTypeModelWeightConfigRaw
}
return processor.NewModelConfigProcessor(b.store, mediaType, []string{filepath}, destDir)
}

if modelfile.IsFileType(filepath, modelfile.ModelFilePatterns) {
mediaType := modelspec.MediaTypeModelWeight
return processor.NewModelConfigProcessor(b.store, mediaType, []string{filepath}, destDir), nil
case modelfile.FileTypeModel:
mediaType = modelspec.MediaTypeModelWeight
if rawMediaType {
mediaType = modelspec.MediaTypeModelWeightRaw
}
return processor.NewModelProcessor(b.store, mediaType, []string{filepath}, destDir)
}

if modelfile.IsFileType(filepath, modelfile.CodeFilePatterns) {
mediaType := modelspec.MediaTypeModelCode
return processor.NewModelProcessor(b.store, mediaType, []string{filepath}, destDir), nil
case modelfile.FileTypeCode:
mediaType = modelspec.MediaTypeModelCode
if rawMediaType {
mediaType = modelspec.MediaTypeModelCodeRaw
}
return processor.NewCodeProcessor(b.store, mediaType, []string{filepath}, destDir)
}

if modelfile.IsFileType(filepath, modelfile.DocFilePatterns) {
mediaType := modelspec.MediaTypeModelDoc
return processor.NewCodeProcessor(b.store, mediaType, []string{filepath}, destDir), nil
case modelfile.FileTypeDoc:
mediaType = modelspec.MediaTypeModelDoc
if rawMediaType {
mediaType = modelspec.MediaTypeModelDocRaw
}
return processor.NewDocProcessor(b.store, mediaType, []string{filepath}, destDir)
return processor.NewDocProcessor(b.store, mediaType, []string{filepath}, destDir), nil
}

return nil
// Unreachable: InferFileType always returns a valid FileType.
return nil, fmt.Errorf("unexpected file type for %s", filepath)
}

func (b *backend) getBuilder(reference string, cfg *config.Attach) (build.Builder, error) {
Expand Down
51 changes: 37 additions & 14 deletions pkg/backend/attach_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,18 @@ import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"reflect"
"testing"

modelspec "github.com/modelpack/model-spec/specs-go/v1"
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/modelpack/modctl/pkg/config"
"github.com/modelpack/modctl/pkg/modelfile"
mockstore "github.com/modelpack/modctl/test/mocks/storage"
)

Expand Down Expand Up @@ -60,30 +64,49 @@ func TestBackendGetManifest(t *testing.T) {

func TestGetProcessor(t *testing.T) {
b := &backend{store: &mockstore.Storage{}}

tempDir := t.TempDir()

tests := []struct {
filepath string
name string
filename string
size int64
wantType string
}{
{"config.yaml", "modelConfigProcessor"},
{"model.pth", "modelProcessor"},
{"script.py", "codeProcessor"},
{"doc.pdf", "docProcessor"},
{"unknown.xyz", ""},
{"config yaml", "config.yaml", 1024, "modelConfigProcessor"},
{"model pth", "model.pth", 1024, "modelProcessor"},
{"code python", "script.py", 1024, "codeProcessor"},
{"doc pdf", "doc.pdf", 1024, "docProcessor"},
{"unknown small fallback to code", "unknown.xyz", 1024, "codeProcessor"},
{"dotfile small fallback to code", ".metadata", 1024, "codeProcessor"},
{"unknown large fallback to model", "large_unknown", modelfile.WeightFileSizeThreshold + 1, "modelProcessor"},
}

for _, tt := range tests {
t.Run(tt.filepath, func(t *testing.T) {
proc := b.getProcessor("", tt.filepath, false)
if tt.wantType == "" {
assert.Nil(t, proc)
} else {
assert.NotNil(t, proc)
assert.Contains(t, fmt.Sprintf("%T", proc), tt.wantType)
}
t.Run(tt.name, func(t *testing.T) {
fp := filepath.Join(tempDir, tt.filename)
f, err := os.Create(fp)
require.NoError(t, err)
require.NoError(t, f.Close())
require.NoError(t, os.Truncate(fp, tt.size))

proc, err := b.getProcessor("", fp, false)
assert.NoError(t, err)
assert.NotNil(t, proc)
assert.Contains(t, fmt.Sprintf("%T", proc), tt.wantType)
})
}
}

func TestGetProcessorFileNotFound(t *testing.T) {
b := &backend{store: &mockstore.Storage{}}

proc, err := b.getProcessor("", "/nonexistent/file.txt", false)
assert.Error(t, err)
assert.Nil(t, proc)
assert.Contains(t, err.Error(), "failed to stat file")
}

func TestSortLayers(t *testing.T) {
testCases := []struct {
name string
Expand Down
6 changes: 3 additions & 3 deletions pkg/backend/upload.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ import (
// Upload uploads the file to a model artifact repository in advance, but will not push config and manifest.
func (b *backend) Upload(ctx context.Context, filepath string, cfg *config.Upload) error {
logrus.Infof("upload: uploading file %s to %s", filepath, cfg.Repo)
proc := b.getProcessor(cfg.DestinationDir, filepath, cfg.Raw)
if proc == nil {
return fmt.Errorf("failed to get processor for file %s", filepath)
proc, err := b.getProcessor(cfg.DestinationDir, filepath, cfg.Raw)
if err != nil {
return fmt.Errorf("failed to get processor: %w", err)
}

opts := []build.Option{
Expand Down
31 changes: 31 additions & 0 deletions pkg/modelfile/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,37 @@ var (
}
)

// FileType represents the inferred type of a file.
type FileType int

const (
FileTypeConfig FileType = iota
FileTypeModel
FileTypeCode
FileTypeDoc
)

// InferFileType determines the file type by extension matching first,
// then falls back to a size-based heuristic for unrecognized files:
// >128MB -> FileTypeModel, otherwise -> FileTypeCode.
func InferFileType(filename string, fileSize int64) FileType {
switch {
case IsFileType(filename, ConfigFilePatterns):
return FileTypeConfig
case IsFileType(filename, ModelFilePatterns):
return FileTypeModel
case IsFileType(filename, CodeFilePatterns):
return FileTypeCode
case IsFileType(filename, DocFilePatterns):
return FileTypeDoc
default:
if SizeShouldBeWeightFile(fileSize) {
return FileTypeModel
}
return FileTypeCode
}
}

const (
// File size thresholds and workspace limits
WeightFileSizeThreshold int64 = 128 * humanize.MByte // 128MB - threshold for considering file as weight file
Expand Down
46 changes: 46 additions & 0 deletions pkg/modelfile/constants_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,52 @@ func TestIsFileTypeDocPatternsTfevents(t *testing.T) {
}
}

func TestInferFileType(t *testing.T) {
testCases := []struct {
name string
filename string
fileSize int64
expected FileType
}{
// Known extensions - size should not matter
{"config json", "config.json", 1024, FileTypeConfig},
{"config yaml", "settings.yaml", 1024, FileTypeConfig},
{"model safetensors", "model.safetensors", 1024, FileTypeModel},
{"model bin", "weights.bin", 1024, FileTypeModel},
{"code python", "script.py", 1024, FileTypeCode},
{"code go", "main.go", 1024, FileTypeCode},
{"doc markdown", "README.md", 1024, FileTypeDoc},
{"doc pdf", "guide.pdf", 1024, FileTypeDoc},

// Dotfile with known secondary extension
{".cache.json is config", ".cache.json", 1024, FileTypeConfig},
{".hidden.py is code", ".hidden.py", 1024, FileTypeCode},

// Unrecognized - small files fallback to code
{"dotfile small", ".metadata", 1024, FileTypeCode},
{"no extension small", "unknown_file", 1024, FileTypeCode},
{"unknown ext small", "data.xyz", 50 * 1024, FileTypeCode},

// Unrecognized - large files fallback to model
{"dotfile large", ".metadata", 200 * 1024 * 1024, FileTypeModel},
{"no extension large", "unknown_file", 200 * 1024 * 1024, FileTypeModel},
{"unknown ext large", "data.xyz", 200 * 1024 * 1024, FileTypeModel},

// Edge case: exactly at threshold (WeightFileSizeThreshold = 128*1000*1000) should be code
{"at threshold", "borderline", WeightFileSizeThreshold, FileTypeCode},
// Just above threshold should be model
{"above threshold", "borderline", WeightFileSizeThreshold + 1, FileTypeModel},
}

assert := assert.New(t)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(tc.expected, InferFileType(tc.filename, tc.fileSize),
"InferFileType(%q, %d)", tc.filename, tc.fileSize)
})
}
}

func TestIsSkippable(t *testing.T) {
testCases := []struct {
filename string
Expand Down
19 changes: 5 additions & 14 deletions pkg/modelfile/modelfile.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,24 +310,15 @@ func (mf *modelfile) generateByWorkspace(config *configmodelfile.GenerateConfig)
return fmt.Errorf("workspace exceeds maximum total size limit of %d bytes (%s)", MaxTotalWorkspaceSize, formatBytes(MaxTotalWorkspaceSize))
}

switch {
case IsFileType(filename, ConfigFilePatterns):
switch InferFileType(filename, info.Size()) {
case FileTypeConfig:
mf.config.Add(relPath)
case IsFileType(filename, ModelFilePatterns):
case FileTypeModel:
mf.model.Add(relPath)
case IsFileType(filename, CodeFilePatterns):
case FileTypeCode:
mf.code.Add(relPath)
case IsFileType(filename, DocFilePatterns):
case FileTypeDoc:
mf.doc.Add(relPath)
default:
// If the file is large, usually it is a weight file.
if SizeShouldBeWeightFile(info.Size()) {
mf.model.Add(relPath)
} else {
mf.code.Add(relPath)
}

return nil
}

return nil
Expand Down
Loading