@@ -546,6 +546,9 @@ func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Ex
546546 case "NewSet" :
547547 pset , errs := oc .processNewSet (info , pkgPath , call , nil , varName )
548548 return pset , notePositionAll (exprPos , errs )
549+ case "Subtract" :
550+ pset , errs := oc .processSubtract (info , pkgPath , call , nil , varName )
551+ return pset , notePositionAll (exprPos , errs )
549552 case "Bind" :
550553 b , err := processBind (oc .fset , info , call )
551554 if err != nil {
@@ -590,6 +593,115 @@ func (oc *objectCache) processExpr(info *types.Info, pkgPath string, expr ast.Ex
590593 return nil , []error {notePosition (exprPos , errors .New ("unknown pattern" ))}
591594}
592595
596+ func (oc * objectCache ) filterType (set * ProviderSet , t types.Type ) []error {
597+ hasType := func (outs []types.Type ) bool {
598+ for _ , o := range outs {
599+ if types .Identical (o , t ) {
600+ return true
601+ }
602+ pt , ok := o .(* types.Pointer )
603+ if ok && types .Identical (pt .Elem (), t ) {
604+ return true
605+ }
606+ }
607+ return false
608+ }
609+
610+ providers := make ([]* Provider , 0 , len (set .Providers ))
611+ for _ , p := range set .Providers {
612+ if ! hasType (p .Out ) {
613+ providers = append (providers , p )
614+ }
615+ }
616+ set .Providers = providers
617+
618+ bindings := make ([]* IfaceBinding , 0 , len (set .Bindings ))
619+ for _ , i := range set .Bindings {
620+ if ! types .Identical (i .Iface , t ) {
621+ bindings = append (bindings , i )
622+ }
623+ }
624+ set .Bindings = bindings
625+
626+ values := make ([]* Value , 0 , len (set .Values ))
627+ for _ , v := range set .Values {
628+ if ! types .Identical (v .Out , t ) {
629+ values = append (values , v )
630+ }
631+ }
632+ set .Values = values
633+
634+ fields := make ([]* Field , 0 , len (set .Fields ))
635+ for _ , f := range set .Fields {
636+ if ! hasType (f .Out ) {
637+ fields = append (fields , f )
638+ }
639+ }
640+ set .Fields = fields
641+
642+ imports := make ([]* ProviderSet , 0 , len (set .Imports ))
643+ for _ , p := range set .Imports {
644+ clone := * p
645+ if errs := oc .filterType (& clone , t ); len (errs ) > 0 {
646+ return errs
647+ }
648+ imports = append (imports , & clone )
649+ }
650+ set .Imports = imports
651+
652+ var errs []error
653+ set .providerMap , set .srcMap , errs = buildProviderMap (oc .fset , oc .hasher , set )
654+ if len (errs ) > 0 {
655+ return errs
656+ }
657+ return nil
658+ }
659+
660+ func (oc * objectCache ) processSubtract (info * types.Info , pkgPath string , call * ast.CallExpr , args * InjectorArgs , varName string ) (interface {}, []error ) {
661+ // Assumes that call.Fun is wire.Subtract.
662+ if len (call .Args ) < 2 {
663+ return nil , []error {notePosition (oc .fset .Position (call .Pos ()),
664+ errors .New ("call to Subtract must specify types to be subtracted" ))}
665+ }
666+ firstArg , errs := oc .processExpr (info , pkgPath , call .Args [0 ], "" )
667+ if len (errs ) > 0 {
668+ return nil , errs
669+ }
670+ set , ok := firstArg .(* ProviderSet )
671+ if ! ok {
672+ return nil , []error {notePosition (oc .fset .Position (call .Pos ()),
673+ fmt .Errorf ("first argument to Subtract must be a Set" )),
674+ }
675+ }
676+ pset := & ProviderSet {
677+ Pos : call .Pos (),
678+ InjectorArgs : args ,
679+ PkgPath : pkgPath ,
680+ VarName : varName ,
681+ // Copy the other fields.
682+ Providers : set .Providers ,
683+ Bindings : set .Bindings ,
684+ Values : set .Values ,
685+ Fields : set .Fields ,
686+ Imports : set .Imports ,
687+ }
688+ ec := new (errorCollector )
689+ for _ , arg := range call .Args [1 :] {
690+ ptr , ok := info .TypeOf (arg ).(* types.Pointer )
691+ if ! ok {
692+ ec .add (notePosition (oc .fset .Position (arg .Pos ()),
693+ errors .New ("argument to Subtract must be a pointer" ),
694+ ))
695+ continue
696+ }
697+ ec .add (oc .filterType (pset , ptr .Elem ())... )
698+ }
699+ if len (ec .errors ) > 0 {
700+ return nil , ec .errors
701+ }
702+ return pset , nil
703+ }
704+
593705func (oc * objectCache ) processNewSet (info * types.Info , pkgPath string , call * ast.CallExpr , args * InjectorArgs , varName string ) (* ProviderSet , []error ) {
594706 // Assumes that call.Fun is wire.NewSet or wire.Build.
595707
@@ -1173,9 +1285,9 @@ func (pt ProvidedType) IsNil() bool {
11731285//
11741286// - For a function provider, this is the first return value type.
11751287// - For a struct provider, this is either the struct type or the pointer type
1176- // whose element type is the struct type.
1177- // - For a value, this is the type of the expression.
1178- // - For an argument, this is the type of the argument.
1288+ // whose element type is the struct type.
1289+ // - For a value, this is the type of the expression.
1290+ // - For an argument, this is the type of the argument.
11791291func (pt ProvidedType ) Type () types.Type {
11801292 return pt .t
11811293}
0 commit comments