1/*
2   Copyright 2020 Docker Compose CLI authors
3
4   Licensed under the Apache License, Version 2.0 (the "License");
5   you may not use this file except in compliance with the License.
6   You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10   Unless required by applicable law or agreed to in writing, software
11   distributed under the License is distributed on an "AS IS" BASIS,
12   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   See the License for the specific language governing permissions and
14   limitations under the License.
15*/
16
17package compose
18
19import (
20	"context"
21	"fmt"
22	"io"
23
24	"github.com/compose-spec/compose-go/types"
25	"github.com/docker/cli/cli/streams"
26	"github.com/docker/compose/v2/pkg/api"
27	moby "github.com/docker/docker/api/types"
28	"github.com/docker/docker/api/types/container"
29	"github.com/docker/docker/pkg/ioutils"
30	"github.com/docker/docker/pkg/stdcopy"
31	"github.com/docker/docker/pkg/stringid"
32	"github.com/moby/term"
33)
34
35func (s *composeService) RunOneOffContainer(ctx context.Context, project *types.Project, opts api.RunOptions) (int, error) {
36	containerID, err := s.prepareRun(ctx, project, opts)
37	if err != nil {
38		return 0, err
39	}
40
41	if opts.Detach {
42		err := s.apiClient.ContainerStart(ctx, containerID, moby.ContainerStartOptions{})
43		if err != nil {
44			return 0, err
45		}
46		fmt.Fprintln(opts.Stdout, containerID)
47		return 0, nil
48	}
49
50	return s.runInteractive(ctx, containerID, opts)
51}
52
53func (s *composeService) runInteractive(ctx context.Context, containerID string, opts api.RunOptions) (int, error) {
54	r, err := s.getEscapeKeyProxy(opts.Stdin)
55	if err != nil {
56		return 0, err
57	}
58
59	stdin, stdout, err := s.getContainerStreams(ctx, containerID)
60	if err != nil {
61		return 0, err
62	}
63
64	in := streams.NewIn(opts.Stdin)
65	if in.IsTerminal() {
66		state, err := term.SetRawTerminal(in.FD())
67		if err != nil {
68			return 0, err
69		}
70		defer term.RestoreTerminal(in.FD(), state) //nolint:errcheck
71	}
72
73	outputDone := make(chan error)
74	inputDone := make(chan error)
75
76	go func() {
77		if opts.Tty {
78			_, err := io.Copy(opts.Stdout, stdout) //nolint:errcheck
79			outputDone <- err
80		} else {
81			_, err := stdcopy.StdCopy(opts.Stdout, opts.Stderr, stdout) //nolint:errcheck
82			outputDone <- err
83		}
84		stdout.Close() //nolint:errcheck
85	}()
86
87	go func() {
88		_, err := io.Copy(stdin, r)
89		inputDone <- err
90		stdin.Close() //nolint:errcheck
91	}()
92
93	err = s.apiClient.ContainerStart(ctx, containerID, moby.ContainerStartOptions{})
94	if err != nil {
95		return 0, err
96	}
97
98	s.monitorTTySize(ctx, containerID, s.apiClient.ContainerResize)
99
100	for {
101		select {
102		case err := <-outputDone:
103			if err != nil {
104				return 0, err
105			}
106			return s.terminateRun(ctx, containerID, opts)
107		case err := <-inputDone:
108			if _, ok := err.(term.EscapeError); ok {
109				return 0, nil
110			}
111			if err != nil {
112				return 0, err
113			}
114			// Wait for output to complete streaming
115		case <-ctx.Done():
116			return 0, ctx.Err()
117		}
118	}
119}
120
121func (s *composeService) terminateRun(ctx context.Context, containerID string, opts api.RunOptions) (exitCode int, err error) {
122	exitCh, errCh := s.apiClient.ContainerWait(ctx, containerID, container.WaitConditionNotRunning)
123	select {
124	case exit := <-exitCh:
125		exitCode = int(exit.StatusCode)
126	case err = <-errCh:
127		return
128	}
129	if opts.AutoRemove {
130		err = s.apiClient.ContainerRemove(ctx, containerID, moby.ContainerRemoveOptions{})
131	}
132	return
133}
134
135func (s *composeService) prepareRun(ctx context.Context, project *types.Project, opts api.RunOptions) (string, error) {
136	if err := prepareVolumes(project); err != nil { // all dependencies already checked, but might miss service img
137		return "", err
138	}
139	service, err := project.GetService(opts.Service)
140	if err != nil {
141		return "", err
142	}
143
144	applyRunOptions(project, &service, opts)
145
146	slug := stringid.GenerateRandomID()
147	if service.ContainerName == "" {
148		service.ContainerName = fmt.Sprintf("%s_%s_run_%s", project.Name, service.Name, stringid.TruncateID(slug))
149	}
150	service.Scale = 1
151	service.StdinOpen = true
152	service.Restart = ""
153	if service.Deploy != nil {
154		service.Deploy.RestartPolicy = nil
155	}
156	service.Labels = service.Labels.Add(api.SlugLabel, slug)
157	service.Labels = service.Labels.Add(api.OneoffLabel, "True")
158
159	if err := s.ensureImagesExists(ctx, project, false); err != nil { // all dependencies already checked, but might miss service img
160		return "", err
161	}
162	if !opts.NoDeps {
163		if err := s.waitDependencies(ctx, project, service); err != nil {
164			return "", err
165		}
166	}
167	created, err := s.createContainer(ctx, project, service, service.ContainerName, 1, opts.Detach && opts.AutoRemove, opts.UseNetworkAliases, true)
168	if err != nil {
169		return "", err
170	}
171	containerID := created.ID
172	return containerID, nil
173}
174
175func (s *composeService) getEscapeKeyProxy(r io.ReadCloser) (io.ReadCloser, error) {
176	var escapeKeys = []byte{16, 17}
177	if s.configFile.DetachKeys != "" {
178		customEscapeKeys, err := term.ToBytes(s.configFile.DetachKeys)
179		if err != nil {
180			return nil, err
181		}
182		escapeKeys = customEscapeKeys
183	}
184	return ioutils.NewReadCloserWrapper(term.NewEscapeProxy(r, escapeKeys), r.Close), nil
185}
186
187func applyRunOptions(project *types.Project, service *types.ServiceConfig, opts api.RunOptions) {
188	service.Tty = opts.Tty
189	service.StdinOpen = true
190	service.ContainerName = opts.Name
191
192	if len(opts.Command) > 0 {
193		service.Command = opts.Command
194	}
195	if len(opts.User) > 0 {
196		service.User = opts.User
197	}
198	if len(opts.WorkingDir) > 0 {
199		service.WorkingDir = opts.WorkingDir
200	}
201	if opts.Entrypoint != nil {
202		service.Entrypoint = opts.Entrypoint
203	}
204	if len(opts.Environment) > 0 {
205		env := types.NewMappingWithEquals(opts.Environment)
206		projectEnv := env.Resolve(func(s string) (string, bool) {
207			v, ok := project.Environment[s]
208			return v, ok
209		}).RemoveEmpty()
210		service.Environment.OverrideBy(projectEnv)
211	}
212	for k, v := range opts.Labels {
213		service.Labels = service.Labels.Add(k, v)
214	}
215}
216