Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 27 additions & 18 deletions plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ var (
ErrPluginNotFound = errors.New("plugin: not found")
// ErrPluginMultipleInstances is used when a plugin is expected a single instance but has multiple
ErrPluginMultipleInstances = errors.New("plugin: multiple instances")
// ErrPluginCircularDependency is used when the graph detect a circular plugin dependency
ErrPluginCircularDependency = errors.New("plugin: dependency loop detected")

// ErrInvalidRequires will be thrown if the requirements for a plugin are
// defined in an invalid manner.
Expand Down Expand Up @@ -110,36 +112,43 @@ type Registry []*Registration
// Graph computes the ordered list of registrations based on their dependencies,
// filtering out any plugins which match the provided filter.
func (registry Registry) Graph(filter DisableFilter) []Registration {
disabled := map[*Registration]bool{}
for _, r := range registry {
if filter(r) {
disabled[r] = true
handled := make(map[*Registration]struct{}, len(registry))
if filter != nil {
for _, r := range registry {
if filter(r) {
handled[r] = struct{}{}
}
}
}

ordered := make([]Registration, 0, len(registry)-len(disabled))
added := map[*Registration]bool{}
ordered := make([]Registration, 0, len(registry)-len(handled))
stack := make([]*Registration, 0, cap(ordered))
for _, r := range registry {
if disabled[r] {
if _, ok := handled[r]; ok {
continue
}
children(r, registry, added, disabled, &ordered)
if !added[r] {
ordered = append(ordered, *r)
added[r] = true
}
children(append(stack, r), registry, handled, &ordered)
handled[r] = struct{}{}
ordered = append(ordered, *r)
}
return ordered
}

func children(reg *Registration, registry []*Registration, added, disabled map[*Registration]bool, ordered *[]Registration) {
func children(stack []*Registration, registry []*Registration, handled map[*Registration]struct{}, ordered *[]Registration) {
reg := stack[len(stack)-1]
for _, t := range reg.Requires {
for _, r := range registry {
if (t == "*" || r.Type == t) && r != reg && !disabled[r] {
children(r, registry, added, disabled, ordered)
if !added[r] {
if (t == "*" || r.Type == t) && r != reg {
if _, ok := handled[r]; !ok {
// Ensure not in current stack
for _, p := range stack[:len(stack)-1] {
if p == r {
panic(fmt.Errorf("circular plugin dependency at %s: %w", r.URI(), ErrPluginCircularDependency))
}
}
children(append(stack, r), registry, handled, ordered)
handled[r] = struct{}{}
*ordered = append(*ordered, *r)
added[r] = true
}
}
}
Expand All @@ -160,7 +169,7 @@ func (registry Registry) Register(r *Registration) Registry {
}

for _, requires := range r.Requires {
if requires == "*" && len(r.Requires) != 1 {
if (requires == "*" && len(r.Requires) != 1) || requires == r.Type {
panic(ErrInvalidRequires)
}
}
Expand Down
153 changes: 152 additions & 1 deletion plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ const (
)

func testRegistry() Registry {

var register Registry
return register.Register(&Registration{
Type: TaskMonitorPlugin,
Expand Down Expand Up @@ -517,6 +516,158 @@ func testPlugin(t Type, id string, i interface{}, err error) *Plugin {
}
}

func TestRequiresAll(t *testing.T) {
var register Registry
register = register.Register(&Registration{
Type: InternalPlugin,
ID: "system",
}).Register(&Registration{
Type: ServicePlugin,
ID: "introspection",
Requires: []Type{
"*",
},
}).Register(&Registration{
Type: ServicePlugin,
ID: "task",
Requires: []Type{
InternalPlugin,
},
}).Register(&Registration{
Type: ServicePlugin,
ID: "version",
})
ordered := register.Graph(mockPluginFilter)
expectedURI := []string{
"io.containerd.internal.v1.system",
"io.containerd.service.v1.task",
"io.containerd.service.v1.version",
"io.containerd.service.v1.introspection",
}
cmpOrdered(t, ordered, expectedURI)
}

func TestRegisterErrors(t *testing.T) {

for _, tc := range []struct {
name string
expected error
register func(Registry) Registry
}{
{
name: "duplicate",
expected: ErrIDRegistered,
register: func(r Registry) Registry {
return r.Register(&Registration{
Type: TaskMonitorPlugin,
ID: "cgroups",
}).Register(&Registration{
Type: TaskMonitorPlugin,
ID: "cgroups",
})
},
},
{
name: "circular",
expected: ErrPluginCircularDependency,
register: func(r Registry) Registry {
// Circular dependencies should not loop but order will be based on registration order
return r.Register(&Registration{
Type: InternalPlugin,
ID: "p1",
Requires: []Type{
RuntimePlugin,
},
}).Register(&Registration{
Type: RuntimePlugin,
ID: "p2",
Requires: []Type{
InternalPlugin,
},
}).Register(&Registration{
Type: InternalPlugin,
ID: "p3",
})
},
},
{
name: "self",
expected: ErrInvalidRequires,
register: func(r Registry) Registry {
// Circular dependencies should not loop but order will be based on registration order
return r.Register(&Registration{
Type: InternalPlugin,
ID: "p1",
Requires: []Type{
InternalPlugin,
},
})
},
},
{
name: "no-type",
expected: ErrNoType,
register: func(r Registry) Registry {
// Circular dependencies should not loop but order will be based on registration order
return r.Register(&Registration{
Type: "",
ID: "p1",
})
},
},
{
name: "no-ID",
expected: ErrNoPluginID,
register: func(r Registry) Registry {
// Circular dependencies should not loop but order will be based on registration order
return r.Register(&Registration{
Type: InternalPlugin,
ID: "",
})
},
},
{
name: "bad-requires-all",
expected: ErrInvalidRequires,
register: func(r Registry) Registry {
// Circular dependencies should not loop but order will be based on registration order
return r.Register(&Registration{
Type: InternalPlugin,
ID: "p1",
Requires: []Type{
"*",
InternalPlugin,
},
})
},
},
} {
t.Run(tc.name, func(t *testing.T) {
var (
r Registry
panicAny any
)
func() {
defer func() {
panicAny = recover()
}()

tc.register(r).Graph(mockPluginFilter)
}()
if panicAny == nil {
t.Fatalf("expected panic with error %v", tc.expected)
}
err, ok := panicAny.(error)
if !ok {
t.Fatalf("expected panic: %v, expected error %v", panicAny, tc.expected)
}
if !errors.Is(err, tc.expected) {
t.Fatalf("unexpected error type: %v, expected %v", panicAny, tc.expected)
}
})
}
}

func BenchmarkGraph(b *testing.B) {
register := testRegistry()
b.ResetTimer()
Expand Down
Loading