Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 42 additions & 30 deletions pkg/distribution/oci/remote/remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -693,8 +693,21 @@ func (l *remoteLayer) MediaType() (oci.MediaType, error) {
return l.desc.MediaType, nil
}

// syncWriter is a thread-safe wrapper around io.Writer for concurrent writes
type syncWriter struct {
w io.Writer
mu sync.Mutex
}

// Write implements io.Writer interface with mutex protection
func (sw *syncWriter) Write(p []byte) (n int, err error) {
sw.mu.Lock()
defer sw.mu.Unlock()
return sw.w.Write(p)
}

// Write pushes an image to a registry.
func Write(ref reference.Reference, img oci.Image, opts ...Option) error {
func Write(ref reference.Reference, img oci.Image, w io.Writer, opts ...Option) error {
o := makeOptions(opts...)

// Pre-authorize with push scope to ensure we have the right permissions
Expand Down Expand Up @@ -724,6 +737,12 @@ func Write(ref reference.Reference, img oci.Image, opts ...Option) error {
totalSize += size
}

// Create a thread-safe writer wrapper for concurrent progress reporting
var safeWriter io.Writer
if w != nil {
safeWriter = &syncWriter{w: w}
}

var completed int64
for _, layer := range layers {
digest, err := layer.Digest()
Expand All @@ -747,6 +766,13 @@ func Write(ref reference.Reference, img oci.Image, opts ...Option) error {
Size: size,
}

var pr *progress.Reporter
var progressChan chan<- oci.Update
if safeWriter != nil {
pr = progress.NewProgressReporter(safeWriter, progress.PushMsg, size, layer)
progressChan = pr.Updates()
}

rc, err := layer.Compressed()
if err != nil {
return fmt.Errorf("getting layer content: %w", err)
Expand All @@ -759,57 +785,58 @@ func Write(ref reference.Reference, img oci.Image, opts ...Option) error {
// If already exists, continue
if errdefs.IsAlreadyExists(err) || strings.Contains(err.Error(), "already exists") {
completed += size
if o.progress != nil {
o.progress <- oci.Update{
if progressChan != nil {
progressChan <- oci.Update{
Complete: completed,
Total: totalSize,
Total: size,
}
}
Comment thread
MetsysEht marked this conversation as resolved.
continue
}
closeProgress(o.progress)
closeProgress(progressChan)
return fmt.Errorf("pushing layer: %w", err)
}

// Wrap the reader with progress tracking to report incremental upload progress
// Uses the shared progress.Reader from internal/progress package
var reader io.Reader = rc
if o.progress != nil {
reader = progress.NewReaderWithOffset(rc, o.progress, completed)
if progressChan != nil {
reader = progress.NewReaderWithOffset(rc, progressChan, completed)
}
Comment thread
MetsysEht marked this conversation as resolved.

if _, err := io.Copy(cw, reader); err != nil {
cw.Close()
rc.Close()
closeProgress(o.progress)
closeProgress(progressChan)
return fmt.Errorf("writing layer: %w", err)
}

if err := cw.Commit(o.ctx, size, desc.Digest); err != nil {
cw.Close()
rc.Close()
if !errdefs.IsAlreadyExists(err) && !strings.Contains(err.Error(), "already exists") {
closeProgress(o.progress)
closeProgress(progressChan)
return fmt.Errorf("committing layer: %w", err)
}
// If it already exists, we still want to update progress
completed += size
if o.progress != nil {
o.progress <- oci.Update{
if progressChan != nil {
progressChan <- oci.Update{
Complete: completed,
Total: totalSize,
Total: size,
}
}
Comment thread
MetsysEht marked this conversation as resolved.
} else {
// Successfully committed, update progress
completed += size
if o.progress != nil {
o.progress <- oci.Update{
if progressChan != nil {
progressChan <- oci.Update{
Complete: completed,
Total: totalSize,
Total: size,
}
}
}
closeProgress(progressChan)
Comment thread
MetsysEht marked this conversation as resolved.
cw.Close()
rc.Close()
}
Expand All @@ -834,20 +861,17 @@ func Write(ref reference.Reference, img oci.Image, opts ...Option) error {
cw, err := pusher.Push(o.ctx, configDesc)
if err != nil {
if !errdefs.IsAlreadyExists(err) && !strings.Contains(err.Error(), "already exists") {
closeProgress(o.progress)
return fmt.Errorf("pushing config: %w", err)
}
// If it already exists, we don't have a writer to close, just continue
} else {
if _, err := cw.Write(rawConfig); err != nil {
cw.Close()
closeProgress(o.progress)
return fmt.Errorf("writing config: %w", err)
}
if err := cw.Commit(o.ctx, int64(len(rawConfig)), configDesc.Digest); err != nil {
cw.Close()
if !errdefs.IsAlreadyExists(err) && !strings.Contains(err.Error(), "already exists") {
closeProgress(o.progress)
return fmt.Errorf("committing config: %w", err)
}
}
Expand All @@ -857,19 +881,16 @@ func Write(ref reference.Reference, img oci.Image, opts ...Option) error {
// Push manifest
rawManifest, err := img.RawManifest()
if err != nil {
closeProgress(o.progress)
return fmt.Errorf("getting manifest: %w", err)
}

manifest, err := img.Manifest()
if err != nil {
closeProgress(o.progress)
return fmt.Errorf("getting manifest object: %w", err)
}

manifestDigest, err := img.Digest()
if err != nil {
closeProgress(o.progress)
return fmt.Errorf("getting manifest digest: %w", err)
}

Expand All @@ -882,24 +903,18 @@ func Write(ref reference.Reference, img oci.Image, opts ...Option) error {
cw, err = pusher.Push(o.ctx, manifestDesc)
if err != nil {
if !errdefs.IsAlreadyExists(err) && !strings.Contains(err.Error(), "already exists") {
closeProgress(o.progress)
return fmt.Errorf("pushing manifest: %w", err)
}
// If it already exists, we don't have a writer to close, just continue
// If it already exists, we still want to close progress and return success
closeProgress(o.progress)
return nil
}

if _, err := cw.Write(rawManifest); err != nil {
cw.Close()
closeProgress(o.progress)
return fmt.Errorf("writing manifest: %w", err)
}

if err := cw.Commit(o.ctx, int64(len(rawManifest)), manifestDesc.Digest); err != nil {
cw.Close()
closeProgress(o.progress)
if !errdefs.IsAlreadyExists(err) && !strings.Contains(err.Error(), "already exists") {
return fmt.Errorf("committing manifest: %w", err)
}
Expand All @@ -908,9 +923,6 @@ func Write(ref reference.Reference, img oci.Image, opts ...Option) error {
}
cw.Close()

// Close progress channel to signal completion
closeProgress(o.progress)

return nil
}

Expand Down
9 changes: 4 additions & 5 deletions pkg/distribution/registry/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"strings"
"sync"

"github.com/docker/model-runner/pkg/distribution/internal/progress"
"github.com/docker/model-runner/pkg/distribution/oci"
"github.com/docker/model-runner/pkg/distribution/oci/authn"
"github.com/docker/model-runner/pkg/distribution/oci/reference"
Expand Down Expand Up @@ -264,15 +263,15 @@ func (t *Target) Write(ctx context.Context, model types.ModelArtifact, progressW
}
imageSize += size
}
pr := progress.NewProgressReporter(progressWriter, progress.PushMsg, imageSize, nil)
defer pr.Wait()
//pr := progress.NewProgressReporter(progressWriter, progress.PushMsg, imageSize, nil)
//defer pr.Wait()

// Set up authentication options
authOpts := []remote.Option{
remote.WithContext(ctx),
remote.WithTransport(t.transport),
remote.WithUserAgent(t.userAgent),
remote.WithProgress(pr.Updates()),
//remote.WithProgress(pr.Updates()),
remote.WithPlainHTTP(t.plainHTTP),
}

Expand All @@ -283,7 +282,7 @@ func (t *Target) Write(ctx context.Context, model types.ModelArtifact, progressW
authOpts = append(authOpts, remote.WithAuthFromKeychain(t.keychain))
}

if err := remote.Write(t.reference, model, authOpts...); err != nil {
if err := remote.Write(t.reference, model, progressWriter, authOpts...); err != nil {
return fmt.Errorf("write to registry %q: %w", t.reference.String(), err)
}
return nil
Expand Down