diff --git a/pkg/backend/attach.go b/pkg/backend/attach.go index 4449af14..734230c3 100644 --- a/pkg/backend/attach.go +++ b/pkg/backend/attach.go @@ -75,9 +75,9 @@ func (b *backend) Attach(ctx context.Context, filepath string, cfg *config.Attac logrus.Infof("attach: loaded source model config [config: %+v]", srcModelConfig) - proc := b.getProcessor(cfg.DestinationDir, filepath, cfg.Raw) - if proc == nil { - return fmt.Errorf("failed to get processor for file %s", filepath) + proc, err := b.getProcessor(cfg.DestinationDir, filepath, cfg.Raw) + if err != nil { + return fmt.Errorf("failed to get processor: %w", err) } builder, err := b.getBuilder(cfg.Target, cfg) @@ -305,40 +305,44 @@ func (b *backend) getModelConfig(ctx context.Context, reference string, desc oci return &model, nil } -func (b *backend) getProcessor(destDir, filepath string, rawMediaType bool) processor.Processor { - if modelfile.IsFileType(filepath, modelfile.ConfigFilePatterns) { - mediaType := modelspec.MediaTypeModelWeightConfig +func (b *backend) getProcessor(destDir, filepath string, rawMediaType bool) (processor.Processor, error) { + info, err := os.Stat(filepath) + if err != nil { + return nil, fmt.Errorf("failed to stat file %s: %w", filepath, err) + } + + fileType := modelfile.InferFileType(filepath, info.Size()) + + var mediaType string + switch fileType { + case modelfile.FileTypeConfig: + mediaType = modelspec.MediaTypeModelWeightConfig if rawMediaType { mediaType = modelspec.MediaTypeModelWeightConfigRaw } - return processor.NewModelConfigProcessor(b.store, mediaType, []string{filepath}, destDir) - } - - if modelfile.IsFileType(filepath, modelfile.ModelFilePatterns) { - mediaType := modelspec.MediaTypeModelWeight + return processor.NewModelConfigProcessor(b.store, mediaType, []string{filepath}, destDir), nil + case modelfile.FileTypeModel: + mediaType = modelspec.MediaTypeModelWeight if rawMediaType { mediaType = modelspec.MediaTypeModelWeightRaw } - return processor.NewModelProcessor(b.store, mediaType, []string{filepath}, destDir) - } - - if modelfile.IsFileType(filepath, modelfile.CodeFilePatterns) { - mediaType := modelspec.MediaTypeModelCode + return processor.NewModelProcessor(b.store, mediaType, []string{filepath}, destDir), nil + case modelfile.FileTypeCode: + mediaType = modelspec.MediaTypeModelCode if rawMediaType { mediaType = modelspec.MediaTypeModelCodeRaw } - return processor.NewCodeProcessor(b.store, mediaType, []string{filepath}, destDir) - } - - if modelfile.IsFileType(filepath, modelfile.DocFilePatterns) { - mediaType := modelspec.MediaTypeModelDoc + return processor.NewCodeProcessor(b.store, mediaType, []string{filepath}, destDir), nil + case modelfile.FileTypeDoc: + mediaType = modelspec.MediaTypeModelDoc if rawMediaType { mediaType = modelspec.MediaTypeModelDocRaw } - return processor.NewDocProcessor(b.store, mediaType, []string{filepath}, destDir) + return processor.NewDocProcessor(b.store, mediaType, []string{filepath}, destDir), nil } - return nil + // Unreachable: InferFileType always returns a valid FileType. + return nil, fmt.Errorf("unexpected file type for %s", filepath) } func (b *backend) getBuilder(reference string, cfg *config.Attach) (build.Builder, error) { diff --git a/pkg/backend/attach_test.go b/pkg/backend/attach_test.go index d9df61b8..3099f412 100644 --- a/pkg/backend/attach_test.go +++ b/pkg/backend/attach_test.go @@ -20,14 +20,18 @@ import ( "context" "encoding/json" "fmt" + "os" + "path/filepath" "reflect" "testing" modelspec "github.com/modelpack/model-spec/specs-go/v1" ocispec "github.com/opencontainers/image-spec/specs-go/v1" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/modelpack/modctl/pkg/config" + "github.com/modelpack/modctl/pkg/modelfile" mockstore "github.com/modelpack/modctl/test/mocks/storage" ) @@ -60,30 +64,49 @@ func TestBackendGetManifest(t *testing.T) { func TestGetProcessor(t *testing.T) { b := &backend{store: &mockstore.Storage{}} + + tempDir := t.TempDir() + tests := []struct { - filepath string + name string + filename string + size int64 wantType string }{ - {"config.yaml", "modelConfigProcessor"}, - {"model.pth", "modelProcessor"}, - {"script.py", "codeProcessor"}, - {"doc.pdf", "docProcessor"}, - {"unknown.xyz", ""}, + {"config yaml", "config.yaml", 1024, "modelConfigProcessor"}, + {"model pth", "model.pth", 1024, "modelProcessor"}, + {"code python", "script.py", 1024, "codeProcessor"}, + {"doc pdf", "doc.pdf", 1024, "docProcessor"}, + {"unknown small fallback to code", "unknown.xyz", 1024, "codeProcessor"}, + {"dotfile small fallback to code", ".metadata", 1024, "codeProcessor"}, + {"unknown large fallback to model", "large_unknown", modelfile.WeightFileSizeThreshold + 1, "modelProcessor"}, } for _, tt := range tests { - t.Run(tt.filepath, func(t *testing.T) { - proc := b.getProcessor("", tt.filepath, false) - if tt.wantType == "" { - assert.Nil(t, proc) - } else { - assert.NotNil(t, proc) - assert.Contains(t, fmt.Sprintf("%T", proc), tt.wantType) - } + t.Run(tt.name, func(t *testing.T) { + fp := filepath.Join(tempDir, tt.filename) + f, err := os.Create(fp) + require.NoError(t, err) + require.NoError(t, f.Close()) + require.NoError(t, os.Truncate(fp, tt.size)) + + proc, err := b.getProcessor("", fp, false) + assert.NoError(t, err) + assert.NotNil(t, proc) + assert.Contains(t, fmt.Sprintf("%T", proc), tt.wantType) }) } } +func TestGetProcessorFileNotFound(t *testing.T) { + b := &backend{store: &mockstore.Storage{}} + + proc, err := b.getProcessor("", "/nonexistent/file.txt", false) + assert.Error(t, err) + assert.Nil(t, proc) + assert.Contains(t, err.Error(), "failed to stat file") +} + func TestSortLayers(t *testing.T) { testCases := []struct { name string diff --git a/pkg/backend/upload.go b/pkg/backend/upload.go index 78d6aefc..cacd32ba 100644 --- a/pkg/backend/upload.go +++ b/pkg/backend/upload.go @@ -31,9 +31,9 @@ import ( // Upload uploads the file to a model artifact repository in advance, but will not push config and manifest. func (b *backend) Upload(ctx context.Context, filepath string, cfg *config.Upload) error { logrus.Infof("upload: uploading file %s to %s", filepath, cfg.Repo) - proc := b.getProcessor(cfg.DestinationDir, filepath, cfg.Raw) - if proc == nil { - return fmt.Errorf("failed to get processor for file %s", filepath) + proc, err := b.getProcessor(cfg.DestinationDir, filepath, cfg.Raw) + if err != nil { + return fmt.Errorf("failed to get processor: %w", err) } opts := []build.Option{ diff --git a/pkg/modelfile/constants.go b/pkg/modelfile/constants.go index bdf16b94..1a38575f 100644 --- a/pkg/modelfile/constants.go +++ b/pkg/modelfile/constants.go @@ -431,6 +431,37 @@ var ( } ) +// FileType represents the inferred type of a file. +type FileType int + +const ( + FileTypeConfig FileType = iota + FileTypeModel + FileTypeCode + FileTypeDoc +) + +// InferFileType determines the file type by extension matching first, +// then falls back to a size-based heuristic for unrecognized files: +// >128MB -> FileTypeModel, otherwise -> FileTypeCode. +func InferFileType(filename string, fileSize int64) FileType { + switch { + case IsFileType(filename, ConfigFilePatterns): + return FileTypeConfig + case IsFileType(filename, ModelFilePatterns): + return FileTypeModel + case IsFileType(filename, CodeFilePatterns): + return FileTypeCode + case IsFileType(filename, DocFilePatterns): + return FileTypeDoc + default: + if SizeShouldBeWeightFile(fileSize) { + return FileTypeModel + } + return FileTypeCode + } +} + const ( // File size thresholds and workspace limits WeightFileSizeThreshold int64 = 128 * humanize.MByte // 128MB - threshold for considering file as weight file diff --git a/pkg/modelfile/constants_test.go b/pkg/modelfile/constants_test.go index cabaca45..77fa50e6 100644 --- a/pkg/modelfile/constants_test.go +++ b/pkg/modelfile/constants_test.go @@ -86,6 +86,52 @@ func TestIsFileTypeDocPatternsTfevents(t *testing.T) { } } +func TestInferFileType(t *testing.T) { + testCases := []struct { + name string + filename string + fileSize int64 + expected FileType + }{ + // Known extensions - size should not matter + {"config json", "config.json", 1024, FileTypeConfig}, + {"config yaml", "settings.yaml", 1024, FileTypeConfig}, + {"model safetensors", "model.safetensors", 1024, FileTypeModel}, + {"model bin", "weights.bin", 1024, FileTypeModel}, + {"code python", "script.py", 1024, FileTypeCode}, + {"code go", "main.go", 1024, FileTypeCode}, + {"doc markdown", "README.md", 1024, FileTypeDoc}, + {"doc pdf", "guide.pdf", 1024, FileTypeDoc}, + + // Dotfile with known secondary extension + {".cache.json is config", ".cache.json", 1024, FileTypeConfig}, + {".hidden.py is code", ".hidden.py", 1024, FileTypeCode}, + + // Unrecognized - small files fallback to code + {"dotfile small", ".metadata", 1024, FileTypeCode}, + {"no extension small", "unknown_file", 1024, FileTypeCode}, + {"unknown ext small", "data.xyz", 50 * 1024, FileTypeCode}, + + // Unrecognized - large files fallback to model + {"dotfile large", ".metadata", 200 * 1024 * 1024, FileTypeModel}, + {"no extension large", "unknown_file", 200 * 1024 * 1024, FileTypeModel}, + {"unknown ext large", "data.xyz", 200 * 1024 * 1024, FileTypeModel}, + + // Edge case: exactly at threshold (WeightFileSizeThreshold = 128*1000*1000) should be code + {"at threshold", "borderline", WeightFileSizeThreshold, FileTypeCode}, + // Just above threshold should be model + {"above threshold", "borderline", WeightFileSizeThreshold + 1, FileTypeModel}, + } + + assert := assert.New(t) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(tc.expected, InferFileType(tc.filename, tc.fileSize), + "InferFileType(%q, %d)", tc.filename, tc.fileSize) + }) + } +} + func TestIsSkippable(t *testing.T) { testCases := []struct { filename string diff --git a/pkg/modelfile/modelfile.go b/pkg/modelfile/modelfile.go index dd2604f9..e6773bea 100644 --- a/pkg/modelfile/modelfile.go +++ b/pkg/modelfile/modelfile.go @@ -310,24 +310,15 @@ func (mf *modelfile) generateByWorkspace(config *configmodelfile.GenerateConfig) return fmt.Errorf("workspace exceeds maximum total size limit of %d bytes (%s)", MaxTotalWorkspaceSize, formatBytes(MaxTotalWorkspaceSize)) } - switch { - case IsFileType(filename, ConfigFilePatterns): + switch InferFileType(filename, info.Size()) { + case FileTypeConfig: mf.config.Add(relPath) - case IsFileType(filename, ModelFilePatterns): + case FileTypeModel: mf.model.Add(relPath) - case IsFileType(filename, CodeFilePatterns): + case FileTypeCode: mf.code.Add(relPath) - case IsFileType(filename, DocFilePatterns): + case FileTypeDoc: mf.doc.Add(relPath) - default: - // If the file is large, usually it is a weight file. - if SizeShouldBeWeightFile(info.Size()) { - mf.model.Add(relPath) - } else { - mf.code.Add(relPath) - } - - return nil } return nil