diff --git a/plugin.go b/plugin.go index 1196e0f..e31b221 100644 --- a/plugin.go +++ b/plugin.go @@ -39,6 +39,8 @@ var ( ErrPluginNotFound = errors.New("plugin: not found") // ErrPluginMultipleInstances is used when a plugin is expected a single instance but has multiple ErrPluginMultipleInstances = errors.New("plugin: multiple instances") + // ErrPluginCircularDependency is used when the graph detect a circular plugin dependency + ErrPluginCircularDependency = errors.New("plugin: dependency loop detected") // ErrInvalidRequires will be thrown if the requirements for a plugin are // defined in an invalid manner. @@ -110,36 +112,43 @@ type Registry []*Registration // Graph computes the ordered list of registrations based on their dependencies, // filtering out any plugins which match the provided filter. func (registry Registry) Graph(filter DisableFilter) []Registration { - disabled := map[*Registration]bool{} - for _, r := range registry { - if filter(r) { - disabled[r] = true + handled := make(map[*Registration]struct{}, len(registry)) + if filter != nil { + for _, r := range registry { + if filter(r) { + handled[r] = struct{}{} + } } } - ordered := make([]Registration, 0, len(registry)-len(disabled)) - added := map[*Registration]bool{} + ordered := make([]Registration, 0, len(registry)-len(handled)) + stack := make([]*Registration, 0, cap(ordered)) for _, r := range registry { - if disabled[r] { + if _, ok := handled[r]; ok { continue } - children(r, registry, added, disabled, &ordered) - if !added[r] { - ordered = append(ordered, *r) - added[r] = true - } + children(append(stack, r), registry, handled, &ordered) + handled[r] = struct{}{} + ordered = append(ordered, *r) } return ordered } -func children(reg *Registration, registry []*Registration, added, disabled map[*Registration]bool, ordered *[]Registration) { +func children(stack []*Registration, registry []*Registration, handled map[*Registration]struct{}, ordered *[]Registration) { + reg := stack[len(stack)-1] for _, t := range reg.Requires { for _, r := range registry { - if (t == "*" || r.Type == t) && r != reg && !disabled[r] { - children(r, registry, added, disabled, ordered) - if !added[r] { + if (t == "*" || r.Type == t) && r != reg { + if _, ok := handled[r]; !ok { + // Ensure not in current stack + for _, p := range stack[:len(stack)-1] { + if p == r { + panic(fmt.Errorf("circular plugin dependency at %s: %w", r.URI(), ErrPluginCircularDependency)) + } + } + children(append(stack, r), registry, handled, ordered) + handled[r] = struct{}{} *ordered = append(*ordered, *r) - added[r] = true } } } @@ -160,7 +169,7 @@ func (registry Registry) Register(r *Registration) Registry { } for _, requires := range r.Requires { - if requires == "*" && len(r.Requires) != 1 { + if (requires == "*" && len(r.Requires) != 1) || requires == r.Type { panic(ErrInvalidRequires) } } diff --git a/plugin_test.go b/plugin_test.go index 671149c..c1436ba 100644 --- a/plugin_test.go +++ b/plugin_test.go @@ -44,7 +44,6 @@ const ( ) func testRegistry() Registry { - var register Registry return register.Register(&Registration{ Type: TaskMonitorPlugin, @@ -517,6 +516,158 @@ func testPlugin(t Type, id string, i interface{}, err error) *Plugin { } } +func TestRequiresAll(t *testing.T) { + var register Registry + register = register.Register(&Registration{ + Type: InternalPlugin, + ID: "system", + }).Register(&Registration{ + Type: ServicePlugin, + ID: "introspection", + Requires: []Type{ + "*", + }, + }).Register(&Registration{ + Type: ServicePlugin, + ID: "task", + Requires: []Type{ + InternalPlugin, + }, + }).Register(&Registration{ + Type: ServicePlugin, + ID: "version", + }) + ordered := register.Graph(mockPluginFilter) + expectedURI := []string{ + "io.containerd.internal.v1.system", + "io.containerd.service.v1.task", + "io.containerd.service.v1.version", + "io.containerd.service.v1.introspection", + } + cmpOrdered(t, ordered, expectedURI) +} + +func TestRegisterErrors(t *testing.T) { + + for _, tc := range []struct { + name string + expected error + register func(Registry) Registry + }{ + { + name: "duplicate", + expected: ErrIDRegistered, + register: func(r Registry) Registry { + return r.Register(&Registration{ + Type: TaskMonitorPlugin, + ID: "cgroups", + }).Register(&Registration{ + Type: TaskMonitorPlugin, + ID: "cgroups", + }) + }, + }, + { + name: "circular", + expected: ErrPluginCircularDependency, + register: func(r Registry) Registry { + // Circular dependencies should not loop but order will be based on registration order + return r.Register(&Registration{ + Type: InternalPlugin, + ID: "p1", + Requires: []Type{ + RuntimePlugin, + }, + }).Register(&Registration{ + Type: RuntimePlugin, + ID: "p2", + Requires: []Type{ + InternalPlugin, + }, + }).Register(&Registration{ + Type: InternalPlugin, + ID: "p3", + }) + }, + }, + { + name: "self", + expected: ErrInvalidRequires, + register: func(r Registry) Registry { + // Circular dependencies should not loop but order will be based on registration order + return r.Register(&Registration{ + Type: InternalPlugin, + ID: "p1", + Requires: []Type{ + InternalPlugin, + }, + }) + }, + }, + { + name: "no-type", + expected: ErrNoType, + register: func(r Registry) Registry { + // Circular dependencies should not loop but order will be based on registration order + return r.Register(&Registration{ + Type: "", + ID: "p1", + }) + }, + }, + { + name: "no-ID", + expected: ErrNoPluginID, + register: func(r Registry) Registry { + // Circular dependencies should not loop but order will be based on registration order + return r.Register(&Registration{ + Type: InternalPlugin, + ID: "", + }) + }, + }, + { + name: "bad-requires-all", + expected: ErrInvalidRequires, + register: func(r Registry) Registry { + // Circular dependencies should not loop but order will be based on registration order + return r.Register(&Registration{ + Type: InternalPlugin, + ID: "p1", + Requires: []Type{ + "*", + InternalPlugin, + }, + }) + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + var ( + r Registry + panicAny any + ) + func() { + defer func() { + panicAny = recover() + }() + + tc.register(r).Graph(mockPluginFilter) + }() + if panicAny == nil { + t.Fatalf("expected panic with error %v", tc.expected) + } + err, ok := panicAny.(error) + if !ok { + t.Fatalf("expected panic: %v, expected error %v", panicAny, tc.expected) + } + if !errors.Is(err, tc.expected) { + t.Fatalf("unexpected error type: %v, expected %v", panicAny, tc.expected) + } + }) + } +} + func BenchmarkGraph(b *testing.B) { register := testRegistry() b.ResetTimer()