1package dotgit
2
import (
	"errors"
	"fmt"
	"os"

	"github.com/jesseduffield/go-git/v5/plumbing"
	"github.com/jesseduffield/go-git/v5/utils/ioutil"

	"github.com/go-git/go-billy/v5"
)
12
13func (d *DotGit) setRef(fileName, content string, old *plumbing.Reference) (err error) {
14	if billy.CapabilityCheck(d.fs, billy.ReadAndWriteCapability) {
15		return d.setRefRwfs(fileName, content, old)
16	}
17
18	return d.setRefNorwfs(fileName, content, old)
19}
20
21func (d *DotGit) setRefRwfs(fileName, content string, old *plumbing.Reference) (err error) {
22	// If we are not checking an old ref, just truncate the file.
23	mode := os.O_RDWR | os.O_CREATE
24	if old == nil {
25		mode |= os.O_TRUNC
26	}
27
28	f, err := d.fs.OpenFile(fileName, mode, 0666)
29	if err != nil {
30		return err
31	}
32
33	defer ioutil.CheckClose(f, &err)
34
35	// Lock is unlocked by the deferred Close above. This is because Unlock
36	// does not imply a fsync and thus there would be a race between
37	// Unlock+Close and other concurrent writers. Adding Sync to go-billy
38	// could work, but this is better (and avoids superfluous syncs).
39	err = f.Lock()
40	if err != nil {
41		return err
42	}
43
44	// this is a no-op to call even when old is nil.
45	err = d.checkReferenceAndTruncate(f, old)
46	if err != nil {
47		return err
48	}
49
50	_, err = f.Write([]byte(content))
51	return err
52}
53
54// There are some filesystems that don't support opening files in RDWD mode.
55// In these filesystems the standard SetRef function can not be used as it
56// reads the reference file to check that it's not modified before updating it.
57//
58// This version of the function writes the reference without extra checks
59// making it compatible with these simple filesystems. This is usually not
60// a problem as they should be accessed by only one process at a time.
61func (d *DotGit) setRefNorwfs(fileName, content string, old *plumbing.Reference) error {
62	_, err := d.fs.Stat(fileName)
63	if err == nil && old != nil {
64		fRead, err := d.fs.Open(fileName)
65		if err != nil {
66			return err
67		}
68
69		ref, err := d.readReferenceFrom(fRead, old.Name().String())
70		fRead.Close()
71
72		if err != nil {
73			return err
74		}
75
76		if ref.Hash() != old.Hash() {
77			return fmt.Errorf("reference has changed concurrently")
78		}
79	}
80
81	f, err := d.fs.Create(fileName)
82	if err != nil {
83		return err
84	}
85
86	defer f.Close()
87
88	_, err = f.Write([]byte(content))
89	return err
90}
91