Skip to content

Commit 8593abc

Browse files
committed
fix addToSet and push to support each op
1 parent 997a713 commit 8593abc

4 files changed

Lines changed: 54 additions & 6 deletions

File tree

pkg/mongoproxy/plugins/schema/schema.go

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,16 @@ import (
66
"io/ioutil"
77
"log"
88
"path"
9+
"strings"
910
"sync/atomic"
1011

12+
"go.mongodb.org/mongo-driver/bson"
13+
"gopkg.in/fsnotify.v1"
14+
1115
"github.com/cespare/xxhash/v2"
1216
"github.com/prometheus/client_golang/prometheus"
1317
"github.com/prometheus/client_golang/prometheus/promauto"
1418
"github.com/sirupsen/logrus"
15-
"go.mongodb.org/mongo-driver/bson"
16-
"gopkg.in/fsnotify.v1"
1719

1820
"github.com/wish/mongoproxy/pkg/bsonutil"
1921
"github.com/wish/mongoproxy/pkg/command"
@@ -114,12 +116,20 @@ func (p *SchemaPlugin) Configure(d bson.D) error {
114116
if err := p.LoadSchema(); err != nil {
115117
return err
116118
}
119+
// skip watcher for unit test
120+
if strings.HasPrefix(p.conf.SchemaPath, "example.json") {
121+
return nil
122+
}
123+
117124
// start watch
118125
watcher, err := fsnotify.NewWatcher()
119126
if err != nil {
120127
log.Fatal(err)
121128
}
122129

130+
defer watcher.Close()
131+
done := make(chan bool)
132+
123133
go func() {
124134
for {
125135
select {
@@ -145,7 +155,7 @@ func (p *SchemaPlugin) Configure(d bson.D) error {
145155
if err := watcher.Add(path.Dir(p.conf.SchemaPath)); err != nil {
146156
return err
147157
}
148-
158+
<-done
149159
return nil
150160
}
151161

@@ -168,7 +178,7 @@ func (p *SchemaPlugin) Process(ctx context.Context, r *plugins.Request, next plu
168178
case *command.FindAndModify:
169179
if len(cmd.Update) > 0 {
170180
schema := p.GetSchema()
171-
logrus.Infof("command findAndModify: %v", cmd.Update)
181+
logrus.Debugf("command findAndModify: %v", cmd.Update)
172182
if err := schema.ValidateUpdate(ctx, cmd.Database, cmd.Collection, cmd.Update, bsonutil.GetBoolDefault(cmd.Upsert, false)); err != nil {
173183
schemaDeny.WithLabelValues(cmd.Database, cmd.Collection, r.CommandName).Inc()
174184
if !p.conf.EnforceSchemaLogOnly {
@@ -182,7 +192,7 @@ func (p *SchemaPlugin) Process(ctx context.Context, r *plugins.Request, next plu
182192
case *command.Update:
183193
schema := p.GetSchema()
184194
for _, updateDoc := range cmd.Updates {
185-
logrus.Infof("print command Update: %v", updateDoc)
195+
logrus.Debugf("print command Update: %v", updateDoc)
186196
if err := schema.ValidateUpdate(ctx, cmd.Database, cmd.Collection, updateDoc.U, bsonutil.GetBoolDefault(updateDoc.Upsert, false)); err != nil {
187197
schemaDeny.WithLabelValues(cmd.Database, cmd.Collection, r.CommandName).Inc()
188198
if !p.conf.EnforceSchemaLogOnly {

pkg/mongoproxy/plugins/schema/type_test.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,10 @@ var (
199199
// push extra field
200200
{DB: "testdb", Collection: "requireonlysuba", In: bson.D{{"$push", bson.D{{"doc.a", "name"}, {"doc.b", 1}}}}, Err: true},
201201
{DB: "testdb", Collection: "requireonlysuba", In: bson.D{{"$push", bson.D{{"a", "name"}, {"doc.b", 1}}}}, Err: true},
202+
//test with each
203+
{DB: "testdb", Collection: "nonrequire", In: bson.D{{"$push", bson.D{{"luckynumbers", bson.D{{"$each", bson.A{1, 2, 3}}}}}}}},
204+
//test with each
205+
{DB: "testdb", Collection: "nonrequire", In: bson.D{{"$push", bson.D{{"luckynumbers", bson.E{"$each", bson.A{1, 2, 3}}}}}}},
202206

203207
//
204208
// pull tests
@@ -337,6 +341,10 @@ var (
337341
// addToSet extra field
338342
{DB: "testdb", Collection: "requireonlysuba", In: bson.D{{"$addToSet", bson.D{{"doc.a", "name"}, {"doc.b", 1}}}}, Err: true},
339343
{DB: "testdb", Collection: "requireonlysuba", In: bson.D{{"$addToSet", bson.D{{"a", "name"}, {"doc.b", 1}}}}, Err: true},
344+
//test with each
345+
{DB: "testdb", Collection: "nonrequire", In: bson.D{{"$addToSet", bson.D{{"luckynumbers", bson.D{{"$each", bson.A{1, 2, 3}}}}}}}},
346+
//test with each
347+
{DB: "testdb", Collection: "nonrequire", In: bson.D{{"$addToSet", bson.D{{"luckynumbers", bson.E{"$each", bson.A{1, 2, 3}}}}}}},
340348

341349
//
342350
// rename tests

pkg/mongoproxy/plugins/schema/types.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ func (c *Collection) ValidateUpdate(ctx context.Context, obj bson.D, upsert bool
267267
}
268268
case "$rename":
269269
renameFields = e.Value.(bson.D).Map()
270-
case "$set", "$pull", "$push", "$addToSet", "$pullAll":
270+
case "$set", "$pull", "$pullAll":
271271
if setFields == nil {
272272
setFields = Mapify(e.Value.(bson.D))
273273
} else {
@@ -276,6 +276,12 @@ func (c *Collection) ValidateUpdate(ctx context.Context, obj bson.D, upsert bool
276276
setFields[item.Key] = item.Value
277277
}
278278
}
279+
case "$addToSet", "$push":
280+
if setFields == nil {
281+
setFields = make(bson.M, len(e.Value.(bson.D)))
282+
}
283+
setFields = MapifyWithOp(e.Value.(bson.D), setFields)
284+
279285
case "$setOnInsert":
280286
insertFields = Mapify(e.Value.(bson.D))
281287
case "$unset":

pkg/mongoproxy/plugins/schema/util.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ import (
66
"reflect"
77
"regexp"
88

9+
"go.mongodb.org/mongo-driver/bson/primitive"
10+
11+
"github.com/sirupsen/logrus"
912
"go.mongodb.org/mongo-driver/bson"
1013
)
1114

@@ -141,6 +144,27 @@ func Mapify(d bson.D) bson.M {
141144
return m
142145
}
143146

147+
// Map creates a map from the elements of the D with operator
148+
// It makes additional process for arrays
149+
func MapifyWithOp(d bson.D, m bson.M) bson.M {
150+
for _, e := range d {
151+
e := processArray(e)
152+
if _, ok := e.Value.(primitive.D); ok {
153+
itemValueSet := e.Value.(bson.D).Map()
154+
if val, ok := itemValueSet["$each"]; ok {
155+
m[e.Key] = val
156+
continue
157+
}
158+
} else if _, ok := e.Value.(primitive.D); ok && e.Value.(bson.E).Key == "$each" {
159+
m[e.Key] = e.Value.(bson.E).Value
160+
continue
161+
}
162+
m[e.Key] = e.Value
163+
logrus.Debugf("Add %s type element to set", fmt.Sprint(reflect.TypeOf(e.Value)))
164+
}
165+
return m
166+
}
167+
144168
// looping and process elements in object
145169
func handleObj(obj bson.D, m bson.M) bson.M {
146170
for _, e := range obj {

0 commit comments

Comments
 (0)