diff --git a/collection_test.go b/collection_test.go index 7d911a1..c2235ec 100644 --- a/collection_test.go +++ b/collection_test.go @@ -1943,6 +1943,111 @@ func TestEstimatedDocumentCountMaxTimeMS(t *testing.T) { }) } +// TestCollectionCount covers db.coll.count() — the deprecated collection method +// still heavily used in mongosh scripts. Zero-arg routes to estimatedDocumentCount +// (preserving the fast metadata path); any-arg routes to countDocuments. +func TestCollectionCount(t *testing.T) { + testutil.RunOnAllDBs(t, func(t *testing.T, db testutil.TestDB) { + dbName := fmt.Sprintf("testdb_coll_count_%s", db.Name) + defer testutil.CleanupDatabase(t, db.Client, dbName) + + ctx := context.Background() + + coll := db.Client.Database(dbName).Collection("users") + _, err := coll.InsertMany(ctx, []any{ + bson.M{"name": "alice", "age": 30, "status": "active"}, + bson.M{"name": "bob", "age": 25, "status": "inactive"}, + bson.M{"name": "charlie", "age": 35, "status": "active"}, + bson.M{"name": "diana", "age": 28, "status": "active"}, + }) + require.NoError(t, err) + + gc := gomongo.NewClient(db.Client) + + // count() — zero args, fast path + result, err := gc.Execute(ctx, dbName, "db.users.count()") + require.NoError(t, err) + require.Equal(t, int64(4), result.Value[0].(int64)) + + // count({}) — explicit empty filter, accurate path + result, err = gc.Execute(ctx, dbName, "db.users.count({})") + require.NoError(t, err) + require.Equal(t, int64(4), result.Value[0].(int64)) + + // count(filter) + result, err = gc.Execute(ctx, dbName, `db.users.count({ status: "active" })`) + require.NoError(t, err) + require.Equal(t, int64(3), result.Value[0].(int64)) + + // count(filter) with comparison operator + result, err = gc.Execute(ctx, dbName, `db.users.count({ age: { $gte: 30 } })`) + require.NoError(t, err) + require.Equal(t, int64(2), result.Value[0].(int64)) + }) +} + +func TestCollectionCountWithOptions(t *testing.T) { + testutil.RunOnAllDBs(t, func(t *testing.T, db testutil.TestDB) { + dbName := fmt.Sprintf("testdb_coll_count_opts_%s", db.Name) + defer testutil.CleanupDatabase(t, db.Client, dbName) + + ctx := context.Background() + + coll := db.Client.Database(dbName).Collection("users") + _, err := coll.InsertMany(ctx, []any{ + bson.M{"name": "alice", "age": 30}, + bson.M{"name": "bob", "age": 25}, + bson.M{"name": "charlie", "age": 35}, + bson.M{"name": "diana", "age": 28}, + bson.M{"name": "eve", "age": 40}, + }) + require.NoError(t, err) + + gc := gomongo.NewClient(db.Client) + + // limit option + result, err := gc.Execute(ctx, dbName, `db.users.count({}, { limit: 3 })`) + require.NoError(t, err) + require.Equal(t, int64(3), result.Value[0].(int64)) + + // skip option + result, err = gc.Execute(ctx, dbName, `db.users.count({}, { skip: 2 })`) + require.NoError(t, err) + require.Equal(t, int64(3), result.Value[0].(int64)) + + // skip + limit + result, err = gc.Execute(ctx, dbName, `db.users.count({}, { skip: 1, limit: 2 })`) + require.NoError(t, err) + require.Equal(t, int64(2), result.Value[0].(int64)) + + // maxTimeMS option + result, err = gc.Execute(ctx, dbName, `db.users.count({}, { maxTimeMS: 5000 })`) + require.NoError(t, err) + require.Equal(t, int64(5), result.Value[0].(int64)) + }) +} + +func TestCollectionCountEmptyCollection(t *testing.T) { + testutil.RunOnAllDBs(t, func(t *testing.T, db testutil.TestDB) { + dbName := fmt.Sprintf("testdb_coll_count_empty_%s", db.Name) + defer testutil.CleanupDatabase(t, db.Client, dbName) + + ctx := context.Background() + + gc := gomongo.NewClient(db.Client) + + // Zero-arg count on missing collection + result, err := gc.Execute(ctx, dbName, "db.users.count()") + require.NoError(t, err) + require.Equal(t, int64(0), result.Value[0].(int64)) + + // count({}) on missing collection + result, err = gc.Execute(ctx, dbName, "db.users.count({})") + require.NoError(t, err) + require.Equal(t, int64(0), result.Value[0].(int64)) + }) +} + func TestDistinct(t *testing.T) { testutil.RunOnAllDBs(t, func(t *testing.T, db testutil.TestDB) { dbName := fmt.Sprintf("testdb_distinct_%s", db.Name) @@ -2116,22 +2221,93 @@ func TestDistinctMaxTimeMS(t *testing.T) { }) } -func TestCursorCountUnsupported(t *testing.T) { +func TestCursorCount(t *testing.T) { testutil.RunOnAllDBs(t, func(t *testing.T, db testutil.TestDB) { dbName := fmt.Sprintf("testdb_cursor_count_%s", db.Name) defer testutil.CleanupDatabase(t, db.Client, dbName) ctx := context.Background() + coll := db.Client.Database(dbName).Collection("users") + _, err := coll.InsertMany(ctx, []any{ + bson.M{"name": "Alice", "status": "active"}, + bson.M{"name": "Bob", "status": "inactive"}, + bson.M{"name": "Charlie", "status": "active"}, + bson.M{"name": "Diana", "status": "active"}, + }) + require.NoError(t, err) + + gc := gomongo.NewClient(db.Client) + + // find().count() — empty filter + result, err := gc.Execute(ctx, dbName, "db.users.find().count()") + require.NoError(t, err) + require.Equal(t, int64(4), result.Value[0].(int64)) + + // find(filter).count() — accumulates filter from find() + result, err = gc.Execute(ctx, dbName, `db.users.find({ status: "active" }).count()`) + require.NoError(t, err) + require.Equal(t, int64(3), result.Value[0].(int64)) + + // find().skip().limit().count() — modern mongosh always honors skip/limit on cursor.count + result, err = gc.Execute(ctx, dbName, "db.users.find().skip(1).limit(2).count()") + require.NoError(t, err) + require.Equal(t, int64(2), result.Value[0].(int64)) + + // itcount and size are aliases for count on a find cursor + result, err = gc.Execute(ctx, dbName, `db.users.find({ status: "active" }).itcount()`) + require.NoError(t, err) + require.Equal(t, int64(3), result.Value[0].(int64)) + + result, err = gc.Execute(ctx, dbName, `db.users.find({ status: "active" }).size()`) + require.NoError(t, err) + require.Equal(t, int64(3), result.Value[0].(int64)) + }) +} + +func TestAggregateItcountUnsupported(t *testing.T) { + testutil.RunOnAllDBs(t, func(t *testing.T, db testutil.TestDB) { + dbName := fmt.Sprintf("testdb_agg_itcount_%s", db.Name) + defer testutil.CleanupDatabase(t, db.Client, dbName) + + ctx := context.Background() + gc := gomongo.NewClient(db.Client) - // cursor.count() is not in the planned registry, should return UnsupportedOperationError - _, err := gc.Execute(ctx, dbName, "db.users.find().count()") + // aggregate cursors expose itcount() in mongosh but require pipeline rewriting + // ($count stage). Out of scope until telemetry shows demand. + _, err := gc.Execute(ctx, dbName, "db.users.aggregate([{ $match: {} }]).itcount()") require.Error(t, err) var unsupportedErr *gomongo.UnsupportedOperationError require.ErrorAs(t, err, &unsupportedErr) - require.Equal(t, "count()", unsupportedErr.Operation) + require.Equal(t, "itcount()", unsupportedErr.Operation) + }) +} + +// TestCursorCountRejectsArgs guards against mongosh's legacy +// cursor.count(applySkipLimit) form silently dropping the boolean: the arg is +// a no-op in modern drivers, and silently dropping it would make incorrect +// caller assumptions hard to debug. +func TestCursorCountRejectsArgs(t *testing.T) { + testutil.RunOnAllDBs(t, func(t *testing.T, db testutil.TestDB) { + dbName := fmt.Sprintf("testdb_cursor_count_args_%s", db.Name) + defer testutil.CleanupDatabase(t, db.Client, dbName) + + ctx := context.Background() + + gc := gomongo.NewClient(db.Client) + + for _, stmt := range []string{ + "db.users.find().count(true)", + "db.users.find().count(false)", + "db.users.find().itcount(true)", + "db.users.find().size(1)", + } { + _, err := gc.Execute(ctx, dbName, stmt) + require.Error(t, err, "expected %q to be rejected", stmt) + require.Contains(t, err.Error(), "takes no arguments") + } }) } diff --git a/internal/translator/translate.go b/internal/translator/translate.go index 6d8b5e0..06f57e6 100644 --- a/internal/translator/translate.go +++ b/internal/translator/translate.go @@ -89,6 +89,19 @@ func translateCollectionStatement(op *Operation, stmt *ast.CollectionStatement) if err := extractEstimatedDocumentCountArgs(op, stmt.Args); err != nil { return nil, err } + case "count": + // db.coll.count() is deprecated in mongosh in favor of countDocuments / + // estimatedDocumentCount. We honor both halves of that recommendation: + // zero-arg count() routes to estimatedDocumentCount (preserves the fast + // metadata path), and any-arg count() routes to countDocuments. + if len(stmt.Args) == 0 { + op.OpType = types.OpEstimatedDocumentCount + } else { + op.OpType = types.OpCountDocuments + if err := extractCountDocumentsArgs(op, stmt.Args); err != nil { + return nil, err + } + } case "distinct": op.OpType = types.OpDistinct if err := extractDistinctArgs(op, stmt.Args); err != nil { @@ -239,6 +252,24 @@ func translateCursorMethod(op *Operation, cm ast.CursorMethod) error { return extractMin(op, cm.Args) case "pretty": return nil // no-op + case "count", "itcount", "size": + // mongosh's cursor.count() never iterates the cursor; it issues a + // separate count command server-side. We mirror that by retargeting + // the operation to CountDocuments with the accumulated + // filter+skip+limit+hint. Aggregate cursors also expose itcount() but + // require pipeline rewriting ($count stage); not yet supported. + if len(cm.Args) > 0 { + // mongosh historically accepted cursor.count(applySkipLimit), but + // the boolean is a no-op in modern drivers (skip/limit always + // apply). Reject rather than silently drop, since "no-op" is + // hard to debug. + return fmt.Errorf("%s() takes no arguments", cm.Method) + } + if op.OpType != types.OpFind { + return &UnsupportedOperationError{Operation: cm.Method + "()"} + } + op.OpType = types.OpCountDocuments + return nil default: return &UnsupportedOperationError{Operation: cm.Method + "()"} }