1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package unified
8
9import (
10	"context"
11	"fmt"
12
13	"go.mongodb.org/mongo-driver/bson"
14	"go.mongodb.org/mongo-driver/mongo"
15	"go.mongodb.org/mongo-driver/mongo/options"
16)
17
18func executeAbortTransaction(ctx context.Context, operation *Operation) (*OperationResult, error) {
19	sess, err := Entities(ctx).Session(operation.Object)
20	if err != nil {
21		return nil, err
22	}
23
24	// AbortTransaction takes no options, so the arguments doc must be nil or empty.
25	elems, _ := operation.Arguments.Elements()
26	if len(elems) > 0 {
27		return nil, fmt.Errorf("unrecognized abortTransaction options %v", operation.Arguments)
28	}
29
30	return NewErrorResult(sess.AbortTransaction(ctx)), nil
31}
32
33func executeEndSession(ctx context.Context, operation *Operation) error {
34	sess, err := Entities(ctx).Session(operation.Object)
35	if err != nil {
36		return err
37	}
38
39	// EnsSession takes no options, so the arguments doc must be nil or empty.
40	elems, _ := operation.Arguments.Elements()
41	if len(elems) > 0 {
42		return fmt.Errorf("unrecognized endSession options %v", operation.Arguments)
43	}
44
45	sess.EndSession(ctx)
46	return nil
47}
48
49func executeCommitTransaction(ctx context.Context, operation *Operation) (*OperationResult, error) {
50	sess, err := Entities(ctx).Session(operation.Object)
51	if err != nil {
52		return nil, err
53	}
54
55	// CommitTransaction takes no options, so the arguments doc must be nil or empty.
56	elems, _ := operation.Arguments.Elements()
57	if len(elems) > 0 {
58		return nil, fmt.Errorf("unrecognized commitTransaction options %v", operation.Arguments)
59	}
60
61	return NewErrorResult(sess.CommitTransaction(ctx)), nil
62}
63
64func executeStartTransaction(ctx context.Context, operation *Operation) (*OperationResult, error) {
65	sess, err := Entities(ctx).Session(operation.Object)
66	if err != nil {
67		return nil, err
68	}
69
70	opts := options.Transaction()
71	if operation.Arguments != nil {
72		var temp TransactionOptions
73		if err := bson.Unmarshal(operation.Arguments, &temp); err != nil {
74			return nil, fmt.Errorf("error unmarshalling arguments to TransactionOptions: %v", err)
75		}
76
77		opts = temp.TransactionOptions
78	}
79
80	return NewErrorResult(sess.StartTransaction(opts)), nil
81}
82
83func executeWithTransaction(ctx context.Context, operation *Operation) error {
84	sess, err := Entities(ctx).Session(operation.Object)
85	if err != nil {
86		return err
87	}
88
89	// Process the "callback" argument. This is an array of Operation objects, each of which should be executed inside
90	// the transaction.
91	callback, err := operation.Arguments.LookupErr("callback")
92	if err != nil {
93		return newMissingArgumentError("callback")
94	}
95	var operations []*Operation
96	if err := callback.Unmarshal(&operations); err != nil {
97		return fmt.Errorf("error transforming callback option to slice of operations: %v", err)
98	}
99
100	// Remove the "callback" field and process the other options.
101	var temp TransactionOptions
102	if err := bson.Unmarshal(RemoveFieldsFromDocument(operation.Arguments, "callback"), &temp); err != nil {
103		return fmt.Errorf("error unmarshalling arguments to TransactionOptions: %v", err)
104	}
105
106	_, err = sess.WithTransaction(ctx, func(sessCtx mongo.SessionContext) (interface{}, error) {
107		for idx, op := range operations {
108			if err := op.Execute(ctx); err != nil {
109				return nil, fmt.Errorf("error executing operation %q at index %d: %v", op.Name, idx, err)
110			}
111		}
112		return nil, nil
113	}, temp.TransactionOptions)
114	return err
115}
116