Skip to content

Commit 1e78590

Browse files
committed
fix: restore legacy schema gen for cog train/predict/serve (skipLabels path)
The skipLabels optimization (8b1c141) skipped the entire post-build phase including legacy schema generation. This broke cog train/predict/serve which need the schema for -i flag parsing and input validation. Move legacy schema gen above the skipLabels early return and add a minimal second Docker build that bundles only the schema file (no labels, pip freeze, or git info). Restore the sourceDir parameter on GenerateOpenAPISchema so ExcludeSource builds can volume-mount the project directory for Python introspection. Re-enable the train_basic and training_setup integration tests that were temporarily skipped.
1 parent f434680 commit 1e78590

4 files changed

Lines changed: 43 additions & 16 deletions

File tree

integration-tests/tests/train_basic.txtar

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
skip 'cog train requires static schema gen which is gated behind COG_STATIC_SCHEMA=1'
21

32
# Test basic training functionality
43

integration-tests/tests/training_setup.txtar

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
skip 'cog train requires static schema gen which is gated behind COG_STATIC_SCHEMA=1'
21

32
# Test that training with setup method works correctly
43

pkg/image/build.go

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -258,22 +258,21 @@ func Build(
258258
}
259259
}
260260

261-
// When skipLabels is true (cog run/predict/serve/train), skip the expensive
262-
// label-adding phase. This image is for local use only and won't be distributed,
263-
// so we don't need metadata labels, pip freeze, schema bundling, or git info.
264-
if skipLabels {
265-
return tmpImageId, nil
266-
}
267-
268-
console.Info("")
269-
270261
// --- Post-build legacy schema generation ---
271262
// For SDK < 0.17.0 (or when static gen was not used), generate the schema
272263
// by running the built image with python -m cog.command.openapi_schema.
264+
// This must run before the skipLabels early return so that cog train/predict/serve
265+
// have a schema available for input validation and -i flag parsing.
273266
if len(schemaJSON) == 0 && !skipSchemaValidation {
274267
console.Info("Validating model schema...")
275268
enableGPU := cfg.Build != nil && cfg.Build.GPU
276-
legacySchema, err := GenerateOpenAPISchema(ctx, dockerCommand, tmpImageId, enableGPU)
269+
// When excludeSource is true (cog serve/predict/train), /src was not
270+
// COPYed into the image, so volume-mount the project directory.
271+
sourceDir := ""
272+
if excludeSource {
273+
sourceDir = dir
274+
}
275+
legacySchema, err := GenerateOpenAPISchema(ctx, dockerCommand, tmpImageId, enableGPU, sourceDir)
277276
if err != nil {
278277
return "", fmt.Errorf("Failed to get type signature: %w", err)
279278
}
@@ -288,6 +287,28 @@ func Build(
288287
}
289288
}
290289

290+
// When skipLabels is true (cog run/predict/serve/train), skip the expensive
291+
// label-adding phase. This image is for local use only and won't be distributed,
292+
// so we don't need metadata labels, pip freeze, or git info.
293+
// We still need the schema bundled, so do a minimal second build to add it.
294+
if skipLabels {
295+
if len(schemaJSON) > 0 {
296+
// Use trailing "/" on the destination so Docker creates the .cog/
297+
// directory even in ExcludeSource images where COPY . /src was
298+
// skipped and .cog/ does not yet exist.
299+
schemaDockerfile := fmt.Sprintf("FROM %s\nCOPY %s .cog/\n", tmpImageId, bundledSchemaFile)
300+
buildOpts := command.ImageBuildOptions{
301+
DockerfileContents: schemaDockerfile,
302+
ImageName: tmpImageId,
303+
ProgressOutput: progressOutput,
304+
}
305+
if _, err := dockerCommand.ImageBuild(ctx, buildOpts); err != nil {
306+
return "", fmt.Errorf("Failed to bundle schema into image: %w", err)
307+
}
308+
}
309+
return tmpImageId, nil
310+
}
311+
291312
console.Info("Adding labels to image...")
292313
console.Info("")
293314

pkg/image/openapi_schema.go

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,10 @@ import (
1414
// image with `python -m cog.command.openapi_schema`. This is the legacy path used
1515
// for SDK versions < 0.17.0 where the schema must be generated at runtime via
1616
// pydantic introspection rather than static analysis.
17-
func GenerateOpenAPISchema(ctx context.Context, dockerClient command.Command, imageName string, enableGPU bool) (map[string]any, error) {
17+
//
18+
// sourceDir, when non-empty, is volume-mounted as /src. This is needed for
19+
// ExcludeSource builds (cog serve/predict/train) where COPY . /src was skipped.
20+
func GenerateOpenAPISchema(ctx context.Context, dockerClient command.Command, imageName string, enableGPU bool, sourceDir string) (map[string]any, error) {
1821
console.Debugf("=== image.GenerateOpenAPISchema %s", imageName)
1922
var stdout bytes.Buffer
2023
var stderr bytes.Buffer
@@ -24,19 +27,24 @@ func GenerateOpenAPISchema(ctx context.Context, dockerClient command.Command, im
2427
gpus = "all"
2528
}
2629

27-
err := docker.RunWithIO(ctx, dockerClient, command.RunOptions{
30+
runOpts := command.RunOptions{
2831
Image: imageName,
2932
Args: []string{
3033
"python", "-m", "cog.command.openapi_schema",
3134
},
3235
GPUs: gpus,
33-
}, nil, &stdout, &stderr)
36+
}
37+
if sourceDir != "" {
38+
runOpts.Volumes = []command.Volume{{Source: sourceDir, Destination: "/src"}}
39+
}
40+
41+
err := docker.RunWithIO(ctx, dockerClient, runOpts, nil, &stdout, &stderr)
3442

3543
if enableGPU && err == docker.ErrMissingDeviceDriver {
3644
console.Debug(stdout.String())
3745
console.Debug(stderr.String())
3846
console.Debug("Missing device driver, re-trying without GPU")
39-
return GenerateOpenAPISchema(ctx, dockerClient, imageName, false)
47+
return GenerateOpenAPISchema(ctx, dockerClient, imageName, false, sourceDir)
4048
}
4149

4250
if err != nil {

0 commit comments

Comments
 (0)