1package codespace
2
3import (
4	"bytes"
5	"context"
6	"encoding/json"
7	"errors"
8	"fmt"
9	"net"
10	"strconv"
11	"strings"
12
13	"github.com/cli/cli/v2/internal/codespaces"
14	"github.com/cli/cli/v2/internal/codespaces/api"
15	"github.com/cli/cli/v2/pkg/cmdutil"
16	"github.com/cli/cli/v2/pkg/liveshare"
17	"github.com/cli/cli/v2/utils"
18	"github.com/muhammadmuzzammil1998/jsonc"
19	"github.com/spf13/cobra"
20	"golang.org/x/sync/errgroup"
21)
22
23// newPortsCmd returns a Cobra "ports" command that displays a table of available ports,
24// according to the specified flags.
25func newPortsCmd(app *App) *cobra.Command {
26	var codespace string
27	var exporter cmdutil.Exporter
28
29	portsCmd := &cobra.Command{
30		Use:   "ports",
31		Short: "List ports in a codespace",
32		Args:  noArgsConstraint,
33		RunE: func(cmd *cobra.Command, args []string) error {
34			return app.ListPorts(cmd.Context(), codespace, exporter)
35		},
36	}
37
38	portsCmd.PersistentFlags().StringVarP(&codespace, "codespace", "c", "", "Name of the codespace")
39	cmdutil.AddJSONFlags(portsCmd, &exporter, portFields)
40
41	portsCmd.AddCommand(newPortsForwardCmd(app))
42	portsCmd.AddCommand(newPortsVisibilityCmd(app))
43
44	return portsCmd
45}
46
// ListPorts lists the ports known to a codespace. When exporter is non-nil the
// ports are written in the exporter's structured format; otherwise a table of
// label, port, visibility and browse URL is printed, with a header row when
// the output is a TTY. Port labels come from the repository's
// .devcontainer/devcontainer.json, which is fetched concurrently with the
// Live Share connection; failing to read it is only logged as a warning.
func (a *App) ListPorts(ctx context.Context, codespaceName string, exporter cmdutil.Exporter) (err error) {
	codespace, err := getOrChooseCodespace(ctx, a.apiClient, codespaceName)
	if err != nil {
		// TODO(josebalius): remove special handling of this error here and in other places
		if errors.Is(err, errNoCodespaces) {
			return err
		}
		return fmt.Errorf("error choosing codespace: %w", err)
	}

	// Start the devcontainer fetch before connecting so the two network
	// round-trips overlap.
	devContainerCh := getDevContainer(ctx, a.apiClient, codespace)

	session, err := codespaces.ConnectToLiveshare(ctx, a, noopLogger(), a.apiClient, codespace)
	if err != nil {
		return fmt.Errorf("error connecting to codespace: %w", err)
	}
	defer safeClose(session, &err)

	a.StartProgressIndicatorWithLabel("Fetching ports")
	ports, err := session.GetSharedServers(ctx)
	a.StopProgressIndicator()
	if err != nil {
		return fmt.Errorf("error getting ports of shared servers: %w", err)
	}

	dcResult := <-devContainerCh
	if dcResult.err != nil {
		// Warn about failure to read the devcontainer file. Not a codespace command error.
		a.errLogger.Printf("Failed to get port names: %v", dcResult.err)
	}

	portInfos := make([]*portInfo, len(ports))
	for i, p := range ports {
		portInfos[i] = &portInfo{
			Port:         p,
			codespace:    codespace,
			devContainer: dcResult.devContainer,
		}
	}

	if err := a.io.StartPager(); err != nil {
		a.errLogger.Printf("error starting pager: %v", err)
	}
	defer a.io.StopPager()

	if exporter != nil {
		return exporter.Write(a.io, portInfos)
	}

	cs := a.io.ColorScheme()
	tp := utils.NewTablePrinter(a.io)

	if tp.IsTTY() {
		tp.AddField("LABEL", nil, nil)
		tp.AddField("PORT", nil, nil)
		tp.AddField("VISIBILITY", nil, nil)
		tp.AddField("BROWSE URL", nil, nil)
		tp.EndRow()
	}

	for _, port := range portInfos {
		tp.AddField(port.Label(), nil, nil)
		tp.AddField(strconv.Itoa(port.SourcePort), nil, cs.Yellow)
		tp.AddField(port.Privacy, nil, nil)
		tp.AddField(port.BrowseURL(), nil, nil)
		tp.EndRow()
	}
	return tp.Render()
}
117
// portInfo decorates a Live Share port with the context needed to render it:
// the codespace it belongs to (used by BrowseURL) and the parsed devcontainer
// configuration, which may be nil (used by Label).
type portInfo struct {
	*liveshare.Port
	codespace    *api.Codespace
	devContainer *devContainer
}
123
// BrowseURL returns the https URL for the port, built from the codespace name
// and the port's source port number under the githubpreview.dev domain.
func (pi *portInfo) BrowseURL() string {
	host := pi.codespace.Name + "-" + strconv.Itoa(pi.Port.SourcePort)
	return "https://" + host + ".githubpreview.dev"
}
127
// Label returns the label configured for this port under "portsAttributes"
// in devcontainer.json. It returns "" when no devcontainer was loaded or the
// port has no entry.
func (pi *portInfo) Label() string {
	if pi.devContainer == nil {
		return ""
	}
	attrs, found := pi.devContainer.PortAttributes[strconv.Itoa(pi.Port.SourcePort)]
	if !found {
		return ""
	}
	return attrs.Label
}
137
// portFields lists, in order, the field names offered to the JSON output
// flags of "gh codespace ports"; each one is handled by portInfo.ExportData.
var portFields = []string{
	"sourcePort",
	"visibility",
	"label",
	"browseUrl",
	// "destinationPort" is intentionally not offered.
	// TODO(mislav): this appears to always be blank?
}
145
146func (pi *portInfo) ExportData(fields []string) map[string]interface{} {
147	data := map[string]interface{}{}
148
149	for _, f := range fields {
150		switch f {
151		case "sourcePort":
152			data[f] = pi.Port.SourcePort
153		case "destinationPort":
154			data[f] = pi.Port.DestinationPort
155		case "visibility":
156			data[f] = pi.Port.Privacy
157		case "label":
158			data[f] = pi.Label()
159		case "browseUrl":
160			data[f] = pi.BrowseURL()
161		default:
162			panic("unkown field: " + f)
163		}
164	}
165
166	return data
167}
168
type (
	// devContainerResult is the single value delivered by getDevContainer:
	// the decoded devcontainer (nil when the file is absent) or the error met.
	devContainerResult struct {
		devContainer *devContainer
		err          error
	}

	// devContainer is the subset of devcontainer.json read by the ports commands.
	devContainer struct {
		// PortAttributes is keyed by port number written in decimal.
		PortAttributes map[string]portAttribute `json:"portsAttributes"`
	}

	// portAttribute holds the per-port settings read here; only the label is used.
	portAttribute struct {
		Label string `json:"label"`
	}
)
181
182func getDevContainer(ctx context.Context, apiClient apiClient, codespace *api.Codespace) <-chan devContainerResult {
183	ch := make(chan devContainerResult, 1)
184	go func() {
185		contents, err := apiClient.GetCodespaceRepositoryContents(ctx, codespace, ".devcontainer/devcontainer.json")
186		if err != nil {
187			ch <- devContainerResult{nil, fmt.Errorf("error getting content: %w", err)}
188			return
189		}
190
191		if contents == nil {
192			ch <- devContainerResult{nil, nil}
193			return
194		}
195
196		convertedJSON := normalizeJSON(jsonc.ToJSON(contents))
197		if !jsonc.Valid(convertedJSON) {
198			ch <- devContainerResult{nil, errors.New("failed to convert json to standard json")}
199			return
200		}
201
202		var container devContainer
203		if err := json.Unmarshal(convertedJSON, &container); err != nil {
204			ch <- devContainerResult{nil, fmt.Errorf("error unmarshaling: %w", err)}
205			return
206		}
207
208		ch <- devContainerResult{&container, nil}
209	}()
210	return ch
211}
212
// newPortsVisibilityCmd builds the "ports visibility" subcommand. It takes
// one or more <port>:<visibility> arguments and applies them to the codespace
// named by the parent command's persistent --codespace flag.
func newPortsVisibilityCmd(app *App) *cobra.Command {
	cmd := &cobra.Command{
		Use:     "visibility <port>:{public|private|org}...",
		Short:   "Change the visibility of the forwarded port",
		Example: "gh codespace ports visibility 80:org 3000:private 8000:public",
		Args:    cobra.MinimumNArgs(1),
	}
	cmd.RunE = func(c *cobra.Command, args []string) error {
		name, err := c.Flags().GetString("codespace")
		if err != nil {
			// --codespace is a persistent string flag registered by the parent
			// "ports" command, so this only fails if that wiring is broken.
			return fmt.Errorf("get codespace flag: %w", err)
		}
		return app.UpdatePortVisibility(c.Context(), name, args)
	}
	return cmd
}
231
// UpdatePortVisibility changes the visibility of one or more ports in a
// codespace. Each element of args has the form "<port>:<visibility>", for
// example "3000:public". All arguments are parsed before connecting. Updates
// are then applied one at a time in argument order, stopping at the first
// failure.
func (a *App) UpdatePortVisibility(ctx context.Context, codespaceName string, args []string) (err error) {
	ports, err := a.parsePortVisibilities(args)
	if err != nil {
		return fmt.Errorf("error parsing port arguments: %w", err)
	}

	codespace, err := getOrChooseCodespace(ctx, a.apiClient, codespaceName)
	if err != nil {
		if errors.Is(err, errNoCodespaces) {
			return err
		}
		return fmt.Errorf("error getting codespace: %w", err)
	}

	session, err := codespaces.ConnectToLiveshare(ctx, a, noopLogger(), a.apiClient, codespace)
	if err != nil {
		return fmt.Errorf("error connecting to codespace: %w", err)
	}
	defer safeClose(session, &err)

	// TODO: check if port visibility can be updated in parallel instead of sequentially
	for _, port := range ports {
		a.StartProgressIndicatorWithLabel(fmt.Sprintf("Updating port %d visibility to: %s", port.number, port.visibility))
		updateErr := session.UpdateSharedServerPrivacy(ctx, port.number, port.visibility)
		a.StopProgressIndicator()
		if updateErr != nil {
			// Report the actual requested visibility; it is not always "public".
			return fmt.Errorf("error updating port %d visibility to %s: %w", port.number, port.visibility, updateErr)
		}
	}

	return nil
}
264
// portVisibility is one parsed "<port>:<visibility>" argument of the
// "ports visibility" command.
type portVisibility struct {
	number     int    // port number in the codespace
	visibility string // requested visibility text as given, e.g. public, private or org
}
269
// parsePortVisibilities parses arguments of the form "<port>:<visibility>"
// (e.g. "3000:public") into portVisibility values, preserving their order.
// Each argument must contain exactly one colon, an integer port and a
// non-empty visibility. The visibility text itself is passed through as given.
func (a *App) parsePortVisibilities(args []string) ([]portVisibility, error) {
	ports := make([]portVisibility, 0, len(args))
	// The loop variable is named arg, not a, so it does not shadow the receiver.
	for _, arg := range args {
		fields := strings.Split(arg, ":")
		if len(fields) != 2 || fields[1] == "" {
			return nil, fmt.Errorf("invalid port visibility format for %q", arg)
		}
		portStr, visibility := fields[0], fields[1]
		portNumber, err := strconv.Atoi(portStr)
		if err != nil {
			return nil, fmt.Errorf("invalid port number in %q: %w", arg, err)
		}
		ports = append(ports, portVisibility{portNumber, visibility})
	}
	return ports, nil
}
286
// newPortsForwardCmd returns the "ports forward" subcommand, which forwards
// one or more <remote-port>:<local-port> pairs from the codespace to the local
// machine. The codespace comes from the parent command's persistent
// --codespace flag.
func newPortsForwardCmd(app *App) *cobra.Command {
	cmd := &cobra.Command{
		Use:   "forward <remote-port>:<local-port>...",
		Short: "Forward ports",
		Args:  cobra.MinimumNArgs(1),
	}
	cmd.RunE = func(c *cobra.Command, args []string) error {
		name, err := c.Flags().GetString("codespace")
		if err != nil {
			// --codespace is a persistent string flag registered by the parent
			// "ports" command, so this only fails if that wiring is broken.
			return fmt.Errorf("get codespace flag: %w", err)
		}
		return app.ForwardPorts(c.Context(), name, args)
	}
	return cmd
}
307
// ForwardPorts forwards each "<remote>:<local>" pair in ports from the
// codespace to a TCP listener on this machine. Listeners bind to the IPv4
// loopback interface only, so forwarded codespace ports are not reachable
// from other hosts on the network. All pairs are forwarded concurrently. The
// call blocks until the first of them fails or ctx is cancelled, which stops
// all the others.
func (a *App) ForwardPorts(ctx context.Context, codespaceName string, ports []string) (err error) {
	portPairs, err := getPortPairs(ports)
	if err != nil {
		return fmt.Errorf("get port pairs: %w", err)
	}

	codespace, err := getOrChooseCodespace(ctx, a.apiClient, codespaceName)
	if err != nil {
		if errors.Is(err, errNoCodespaces) {
			return err
		}
		return fmt.Errorf("error getting codespace: %w", err)
	}

	session, err := codespaces.ConnectToLiveshare(ctx, a, noopLogger(), a.apiClient, codespace)
	if err != nil {
		return fmt.Errorf("error connecting to codespace: %w", err)
	}
	defer safeClose(session, &err)

	// Run forwarding of all ports concurrently, aborting all of
	// them at the first failure, including cancellation of the context.
	group, ctx := errgroup.WithContext(ctx)
	for _, pair := range portPairs {
		pair := pair
		group.Go(func() error {
			// Bind to loopback only: ":port" would listen on every interface and
			// expose the codespace port to anyone on the local network.
			listen, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", pair.local))
			if err != nil {
				return fmt.Errorf("error listening on local port %d: %w", pair.local, err)
			}
			defer listen.Close()

			a.errLogger.Printf("Forwarding ports: remote %d <=> local %d", pair.remote, pair.local)
			name := fmt.Sprintf("share-%d", pair.remote)
			fwd := liveshare.NewPortForwarder(session, name, pair.remote, false)
			return fwd.ForwardToListener(ctx, listen) // error always non-nil
		})
	}
	return group.Wait() // first error
}
348
349type portPair struct {
350	remote, local int
351}
352
353// getPortPairs parses a list of strings of form "%d:%d" into pairs of (remote, local) numbers.
354func getPortPairs(ports []string) ([]portPair, error) {
355	pp := make([]portPair, 0, len(ports))
356
357	for _, portString := range ports {
358		parts := strings.Split(portString, ":")
359		if len(parts) < 2 {
360			return nil, fmt.Errorf("port pair: %q is not valid", portString)
361		}
362
363		remote, err := strconv.Atoi(parts[0])
364		if err != nil {
365			return pp, fmt.Errorf("convert remote port to int: %w", err)
366		}
367
368		local, err := strconv.Atoi(parts[1])
369		if err != nil {
370			return pp, fmt.Errorf("convert local port to int: %w", err)
371		}
372
373		pp = append(pp, portPair{remote, local})
374	}
375
376	return pp, nil
377}
378
379func normalizeJSON(j []byte) []byte {
380	// remove trailing commas
381	return bytes.ReplaceAll(j, []byte("},}"), []byte("}}"))
382}
383