Skip to content

Commit 1c9e7ff

Browse files
authored
fix: use InferFileType in getProcessor for unrecognized file extensions (#499)
* feat(modelfile): add InferFileType with size-based heuristic fallback

  Adds FileType enum and InferFileType function that combines extension pattern matching with a size-based heuristic for unrecognized files (>128MB -> model weight, otherwise -> code).

  Refs: #497
  Signed-off-by: Zhao Chen <winters.zc@antgroup.com>

* refactor(modelfile): use InferFileType in workspace scanner

  Replace inline switch-case with InferFileType call, eliminating duplicated file type classification logic.

  Refs: #497
  Signed-off-by: Zhao Chen <winters.zc@antgroup.com>

* fix(backend): use InferFileType in getProcessor for unrecognized files

  getProcessor now calls modelfile.InferFileType instead of returning nil for unrecognized file extensions. Unrecognized files fall back to size-based heuristic: >128MB treated as model weight, otherwise as code. Signature changed to return error (for os.Stat failure).

  Fixes: #497
  Signed-off-by: Zhao Chen <winters.zc@antgroup.com>

* fix(backend): remove redundant filename from error wrapping

  getProcessor already includes the filename in its error messages, so the callers in Attach and Upload no longer duplicate it.

  Signed-off-by: Zhao Chen <winters.zc@antgroup.com>

---------

Signed-off-by: Zhao Chen <winters.zc@antgroup.com>
1 parent 215690c commit 1c9e7ff

6 files changed

Lines changed: 149 additions & 54 deletions

File tree

pkg/backend/attach.go

Lines changed: 27 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -75,9 +75,9 @@ func (b *backend) Attach(ctx context.Context, filepath string, cfg *config.Attac
7575

7676
logrus.Infof("attach: loaded source model config [config: %+v]", srcModelConfig)
7777

78-
proc := b.getProcessor(cfg.DestinationDir, filepath, cfg.Raw)
79-
if proc == nil {
80-
return fmt.Errorf("failed to get processor for file %s", filepath)
78+
proc, err := b.getProcessor(cfg.DestinationDir, filepath, cfg.Raw)
79+
if err != nil {
80+
return fmt.Errorf("failed to get processor: %w", err)
8181
}
8282

8383
builder, err := b.getBuilder(cfg.Target, cfg)
@@ -305,40 +305,44 @@ func (b *backend) getModelConfig(ctx context.Context, reference string, desc oci
305305
return &model, nil
306306
}
307307

308-
func (b *backend) getProcessor(destDir, filepath string, rawMediaType bool) processor.Processor {
309-
if modelfile.IsFileType(filepath, modelfile.ConfigFilePatterns) {
310-
mediaType := modelspec.MediaTypeModelWeightConfig
308+
func (b *backend) getProcessor(destDir, filepath string, rawMediaType bool) (processor.Processor, error) {
309+
info, err := os.Stat(filepath)
310+
if err != nil {
311+
return nil, fmt.Errorf("failed to stat file %s: %w", filepath, err)
312+
}
313+
314+
fileType := modelfile.InferFileType(filepath, info.Size())
315+
316+
var mediaType string
317+
switch fileType {
318+
case modelfile.FileTypeConfig:
319+
mediaType = modelspec.MediaTypeModelWeightConfig
311320
if rawMediaType {
312321
mediaType = modelspec.MediaTypeModelWeightConfigRaw
313322
}
314-
return processor.NewModelConfigProcessor(b.store, mediaType, []string{filepath}, destDir)
315-
}
316-
317-
if modelfile.IsFileType(filepath, modelfile.ModelFilePatterns) {
318-
mediaType := modelspec.MediaTypeModelWeight
323+
return processor.NewModelConfigProcessor(b.store, mediaType, []string{filepath}, destDir), nil
324+
case modelfile.FileTypeModel:
325+
mediaType = modelspec.MediaTypeModelWeight
319326
if rawMediaType {
320327
mediaType = modelspec.MediaTypeModelWeightRaw
321328
}
322-
return processor.NewModelProcessor(b.store, mediaType, []string{filepath}, destDir)
323-
}
324-
325-
if modelfile.IsFileType(filepath, modelfile.CodeFilePatterns) {
326-
mediaType := modelspec.MediaTypeModelCode
329+
return processor.NewModelProcessor(b.store, mediaType, []string{filepath}, destDir), nil
330+
case modelfile.FileTypeCode:
331+
mediaType = modelspec.MediaTypeModelCode
327332
if rawMediaType {
328333
mediaType = modelspec.MediaTypeModelCodeRaw
329334
}
330-
return processor.NewCodeProcessor(b.store, mediaType, []string{filepath}, destDir)
331-
}
332-
333-
if modelfile.IsFileType(filepath, modelfile.DocFilePatterns) {
334-
mediaType := modelspec.MediaTypeModelDoc
335+
return processor.NewCodeProcessor(b.store, mediaType, []string{filepath}, destDir), nil
336+
case modelfile.FileTypeDoc:
337+
mediaType = modelspec.MediaTypeModelDoc
335338
if rawMediaType {
336339
mediaType = modelspec.MediaTypeModelDocRaw
337340
}
338-
return processor.NewDocProcessor(b.store, mediaType, []string{filepath}, destDir)
341+
return processor.NewDocProcessor(b.store, mediaType, []string{filepath}, destDir), nil
339342
}
340343

341-
return nil
344+
// Unreachable: InferFileType always returns a valid FileType.
345+
return nil, fmt.Errorf("unexpected file type for %s", filepath)
342346
}
343347

344348
func (b *backend) getBuilder(reference string, cfg *config.Attach) (build.Builder, error) {

pkg/backend/attach_test.go

Lines changed: 37 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -20,14 +20,18 @@ import (
2020
"context"
2121
"encoding/json"
2222
"fmt"
23+
"os"
24+
"path/filepath"
2325
"reflect"
2426
"testing"
2527

2628
modelspec "github.com/modelpack/model-spec/specs-go/v1"
2729
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
2830
"github.com/stretchr/testify/assert"
31+
"github.com/stretchr/testify/require"
2932

3033
"github.com/modelpack/modctl/pkg/config"
34+
"github.com/modelpack/modctl/pkg/modelfile"
3135
mockstore "github.com/modelpack/modctl/test/mocks/storage"
3236
)
3337

@@ -60,30 +64,49 @@ func TestBackendGetManifest(t *testing.T) {
6064

6165
func TestGetProcessor(t *testing.T) {
6266
b := &backend{store: &mockstore.Storage{}}
67+
68+
tempDir := t.TempDir()
69+
6370
tests := []struct {
64-
filepath string
71+
name string
72+
filename string
73+
size int64
6574
wantType string
6675
}{
67-
{"config.yaml", "modelConfigProcessor"},
68-
{"model.pth", "modelProcessor"},
69-
{"script.py", "codeProcessor"},
70-
{"doc.pdf", "docProcessor"},
71-
{"unknown.xyz", ""},
76+
{"config yaml", "config.yaml", 1024, "modelConfigProcessor"},
77+
{"model pth", "model.pth", 1024, "modelProcessor"},
78+
{"code python", "script.py", 1024, "codeProcessor"},
79+
{"doc pdf", "doc.pdf", 1024, "docProcessor"},
80+
{"unknown small fallback to code", "unknown.xyz", 1024, "codeProcessor"},
81+
{"dotfile small fallback to code", ".metadata", 1024, "codeProcessor"},
82+
{"unknown large fallback to model", "large_unknown", modelfile.WeightFileSizeThreshold + 1, "modelProcessor"},
7283
}
7384

7485
for _, tt := range tests {
75-
t.Run(tt.filepath, func(t *testing.T) {
76-
proc := b.getProcessor("", tt.filepath, false)
77-
if tt.wantType == "" {
78-
assert.Nil(t, proc)
79-
} else {
80-
assert.NotNil(t, proc)
81-
assert.Contains(t, fmt.Sprintf("%T", proc), tt.wantType)
82-
}
86+
t.Run(tt.name, func(t *testing.T) {
87+
fp := filepath.Join(tempDir, tt.filename)
88+
f, err := os.Create(fp)
89+
require.NoError(t, err)
90+
require.NoError(t, f.Close())
91+
require.NoError(t, os.Truncate(fp, tt.size))
92+
93+
proc, err := b.getProcessor("", fp, false)
94+
assert.NoError(t, err)
95+
assert.NotNil(t, proc)
96+
assert.Contains(t, fmt.Sprintf("%T", proc), tt.wantType)
8397
})
8498
}
8599
}
86100

101+
func TestGetProcessorFileNotFound(t *testing.T) {
102+
b := &backend{store: &mockstore.Storage{}}
103+
104+
proc, err := b.getProcessor("", "/nonexistent/file.txt", false)
105+
assert.Error(t, err)
106+
assert.Nil(t, proc)
107+
assert.Contains(t, err.Error(), "failed to stat file")
108+
}
109+
87110
func TestSortLayers(t *testing.T) {
88111
testCases := []struct {
89112
name string

pkg/backend/upload.go

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -31,9 +31,9 @@ import (
3131
// Upload uploads the file to a model artifact repository in advance, but will not push config and manifest.
3232
func (b *backend) Upload(ctx context.Context, filepath string, cfg *config.Upload) error {
3333
logrus.Infof("upload: uploading file %s to %s", filepath, cfg.Repo)
34-
proc := b.getProcessor(cfg.DestinationDir, filepath, cfg.Raw)
35-
if proc == nil {
36-
return fmt.Errorf("failed to get processor for file %s", filepath)
34+
proc, err := b.getProcessor(cfg.DestinationDir, filepath, cfg.Raw)
35+
if err != nil {
36+
return fmt.Errorf("failed to get processor: %w", err)
3737
}
3838

3939
opts := []build.Option{

pkg/modelfile/constants.go

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -431,6 +431,37 @@ var (
431431
}
432432
)
433433

434+
// FileType represents the inferred type of a file.
435+
type FileType int
436+
437+
const (
438+
FileTypeConfig FileType = iota
439+
FileTypeModel
440+
FileTypeCode
441+
FileTypeDoc
442+
)
443+
444+
// InferFileType determines the file type by extension matching first,
445+
// then falls back to a size-based heuristic for unrecognized files:
446+
// >128MB -> FileTypeModel, otherwise -> FileTypeCode.
447+
func InferFileType(filename string, fileSize int64) FileType {
448+
switch {
449+
case IsFileType(filename, ConfigFilePatterns):
450+
return FileTypeConfig
451+
case IsFileType(filename, ModelFilePatterns):
452+
return FileTypeModel
453+
case IsFileType(filename, CodeFilePatterns):
454+
return FileTypeCode
455+
case IsFileType(filename, DocFilePatterns):
456+
return FileTypeDoc
457+
default:
458+
if SizeShouldBeWeightFile(fileSize) {
459+
return FileTypeModel
460+
}
461+
return FileTypeCode
462+
}
463+
}
464+
434465
const (
435466
// File size thresholds and workspace limits
436467
WeightFileSizeThreshold int64 = 128 * humanize.MByte // 128MB - threshold for considering file as weight file

pkg/modelfile/constants_test.go

Lines changed: 46 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -86,6 +86,52 @@ func TestIsFileTypeDocPatternsTfevents(t *testing.T) {
8686
}
8787
}
8888

89+
func TestInferFileType(t *testing.T) {
90+
testCases := []struct {
91+
name string
92+
filename string
93+
fileSize int64
94+
expected FileType
95+
}{
96+
// Known extensions - size should not matter
97+
{"config json", "config.json", 1024, FileTypeConfig},
98+
{"config yaml", "settings.yaml", 1024, FileTypeConfig},
99+
{"model safetensors", "model.safetensors", 1024, FileTypeModel},
100+
{"model bin", "weights.bin", 1024, FileTypeModel},
101+
{"code python", "script.py", 1024, FileTypeCode},
102+
{"code go", "main.go", 1024, FileTypeCode},
103+
{"doc markdown", "README.md", 1024, FileTypeDoc},
104+
{"doc pdf", "guide.pdf", 1024, FileTypeDoc},
105+
106+
// Dotfile with known secondary extension
107+
{".cache.json is config", ".cache.json", 1024, FileTypeConfig},
108+
{".hidden.py is code", ".hidden.py", 1024, FileTypeCode},
109+
110+
// Unrecognized - small files fallback to code
111+
{"dotfile small", ".metadata", 1024, FileTypeCode},
112+
{"no extension small", "unknown_file", 1024, FileTypeCode},
113+
{"unknown ext small", "data.xyz", 50 * 1024, FileTypeCode},
114+
115+
// Unrecognized - large files fallback to model
116+
{"dotfile large", ".metadata", 200 * 1024 * 1024, FileTypeModel},
117+
{"no extension large", "unknown_file", 200 * 1024 * 1024, FileTypeModel},
118+
{"unknown ext large", "data.xyz", 200 * 1024 * 1024, FileTypeModel},
119+
120+
// Edge case: exactly at threshold (WeightFileSizeThreshold = 128*1000*1000) should be code
121+
{"at threshold", "borderline", WeightFileSizeThreshold, FileTypeCode},
122+
// Just above threshold should be model
123+
{"above threshold", "borderline", WeightFileSizeThreshold + 1, FileTypeModel},
124+
}
125+
126+
assert := assert.New(t)
127+
for _, tc := range testCases {
128+
t.Run(tc.name, func(t *testing.T) {
129+
assert.Equal(tc.expected, InferFileType(tc.filename, tc.fileSize),
130+
"InferFileType(%q, %d)", tc.filename, tc.fileSize)
131+
})
132+
}
133+
}
134+
89135
func TestIsSkippable(t *testing.T) {
90136
testCases := []struct {
91137
filename string

pkg/modelfile/modelfile.go

Lines changed: 5 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -310,24 +310,15 @@ func (mf *modelfile) generateByWorkspace(config *configmodelfile.GenerateConfig)
310310
return fmt.Errorf("workspace exceeds maximum total size limit of %d bytes (%s)", MaxTotalWorkspaceSize, formatBytes(MaxTotalWorkspaceSize))
311311
}
312312

313-
switch {
314-
case IsFileType(filename, ConfigFilePatterns):
313+
switch InferFileType(filename, info.Size()) {
314+
case FileTypeConfig:
315315
mf.config.Add(relPath)
316-
case IsFileType(filename, ModelFilePatterns):
316+
case FileTypeModel:
317317
mf.model.Add(relPath)
318-
case IsFileType(filename, CodeFilePatterns):
318+
case FileTypeCode:
319319
mf.code.Add(relPath)
320-
case IsFileType(filename, DocFilePatterns):
320+
case FileTypeDoc:
321321
mf.doc.Add(relPath)
322-
default:
323-
// If the file is large, usually it is a weight file.
324-
if SizeShouldBeWeightFile(info.Size()) {
325-
mf.model.Add(relPath)
326-
} else {
327-
mf.code.Add(relPath)
328-
}
329-
330-
return nil
331322
}
332323

333324
return nil

0 commit comments

Comments
 (0)