1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package unified
8
9import (
10	"context"
11	"fmt"
12
13	"go.mongodb.org/mongo-driver/bson"
14	"go.mongodb.org/mongo-driver/mongo"
15	"go.mongodb.org/mongo-driver/mongo/integration/mtest"
16	"go.mongodb.org/mongo-driver/mongo/options"
17)
18
// errorInterrupted is the server error code (11601) that TerminateOpenSessions treats as success when the
// killAllSessions command fails with it.
const errorInterrupted int32 = 11601
22
23// TerminateOpenSessions executes a killAllSessions command to ensure that sesssions left open on the server by a test
24// do not cause future tests to hang.
25func TerminateOpenSessions(ctx context.Context) error {
26	if mtest.CompareServerVersions(mtest.ServerVersion(), "3.6") < 0 {
27		return nil
28	}
29
30	commandFn := func(ctx context.Context, client *mongo.Client) error {
31		cmd := bson.D{
32			{"killAllSessions", bson.A{}},
33		}
34
35		err := client.Database("admin").RunCommand(ctx, cmd).Err()
36		if ce, ok := err.(mongo.CommandError); ok && ce.Code == errorInterrupted {
37			// Workaround for SERVER-38335 on server versions below 4.2.
38			err = nil
39		}
40		return err
41	}
42
43	// For sharded clusters, this has to run against all mongos nodes. Otherwise, it can just against on the primary.
44	if mtest.ClusterTopologyKind() != mtest.Sharded {
45		return commandFn(ctx, mtest.GlobalClient())
46	}
47	return runAgainstAllMongoses(ctx, commandFn)
48}
49
50// PerformDistinctWorkaround executes a non-transactional "distinct" command against each mongos in a sharded cluster.
51func PerformDistinctWorkaround(ctx context.Context) error {
52	commandFn := func(ctx context.Context, client *mongo.Client) error {
53		for _, coll := range Entities(ctx).Collections() {
54			newColl := client.Database(coll.Database().Name()).Collection(coll.Name())
55			_, err := newColl.Distinct(ctx, "x", bson.D{})
56			if err != nil {
57				ns := fmt.Sprintf("%s.%s", coll.Database().Name(), coll.Name())
58				return fmt.Errorf("error running distinct for collection %q: %v", ns, err)
59			}
60		}
61
62		return nil
63	}
64
65	return runAgainstAllMongoses(ctx, commandFn)
66}
67
68func RunCommandOnHost(ctx context.Context, host string, commandFn func(context.Context, *mongo.Client) error) error {
69	clientOpts := options.Client().
70		ApplyURI(mtest.ClusterURI()).
71		SetHosts([]string{host})
72
73	client, err := mongo.Connect(ctx, clientOpts)
74	if err != nil {
75		return fmt.Errorf("error creating client to host %q: %v", host, err)
76	}
77	defer client.Disconnect(ctx)
78
79	return commandFn(ctx, client)
80}
81
82func runAgainstAllMongoses(ctx context.Context, commandFn func(context.Context, *mongo.Client) error) error {
83	for _, host := range mtest.ClusterConnString().Hosts {
84		if err := RunCommandOnHost(ctx, host, commandFn); err != nil {
85			return fmt.Errorf("error executing callback against host %q: %v", host, err)
86		}
87	}
88	return nil
89}
90