1package callbacks
2
3import (
4	"reflect"
5	"strings"
6
7	"gorm.io/gorm"
8	"gorm.io/gorm/clause"
9	"gorm.io/gorm/schema"
10	"gorm.io/gorm/utils"
11)
12
13func BeforeDelete(db *gorm.DB) {
14	if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.BeforeDelete {
15		callMethod(db, func(value interface{}, tx *gorm.DB) bool {
16			if i, ok := value.(BeforeDeleteInterface); ok {
17				db.AddError(i.BeforeDelete(tx))
18				return true
19			}
20
21			return false
22		})
23	}
24}
25
26func DeleteBeforeAssociations(db *gorm.DB) {
27	if db.Error == nil && db.Statement.Schema != nil {
28		selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false)
29
30		if restricted {
31			for column, v := range selectColumns {
32				if v {
33					if rel, ok := db.Statement.Schema.Relationships.Relations[column]; ok {
34						switch rel.Type {
35						case schema.HasOne, schema.HasMany:
36							queryConds := rel.ToQueryConditions(db.Statement.ReflectValue)
37							modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
38							tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue)
39							withoutConditions := false
40							if db.Statement.Unscoped {
41								tx = tx.Unscoped()
42							}
43
44							if len(db.Statement.Selects) > 0 {
45								selects := make([]string, 0, len(db.Statement.Selects))
46								for _, s := range db.Statement.Selects {
47									if s == clause.Associations {
48										selects = append(selects, s)
49									} else if strings.HasPrefix(s, column+".") {
50										selects = append(selects, strings.TrimPrefix(s, column+"."))
51									}
52								}
53
54								if len(selects) > 0 {
55									tx = tx.Select(selects)
56								}
57							}
58
59							for _, cond := range queryConds {
60								if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 {
61									withoutConditions = true
62									break
63								}
64							}
65
66							if !withoutConditions {
67								if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
68									return
69								}
70							}
71						case schema.Many2Many:
72							var (
73								queryConds     = make([]clause.Expression, 0, len(rel.References))
74								foreignFields  = make([]*schema.Field, 0, len(rel.References))
75								relForeignKeys = make([]string, 0, len(rel.References))
76								modelValue     = reflect.New(rel.JoinTable.ModelType).Interface()
77								table          = rel.JoinTable.Table
78								tx             = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table)
79							)
80
81							for _, ref := range rel.References {
82								if ref.OwnPrimaryKey {
83									foreignFields = append(foreignFields, ref.PrimaryKey)
84									relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName)
85								} else if ref.PrimaryValue != "" {
86									queryConds = append(queryConds, clause.Eq{
87										Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName},
88										Value:  ref.PrimaryValue,
89									})
90								}
91							}
92
93							_, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, foreignFields)
94							column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues)
95							queryConds = append(queryConds, clause.IN{Column: column, Values: values})
96
97							if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil {
98								return
99							}
100						}
101					}
102				}
103			}
104		}
105	}
106}
107
108func Delete(config *Config) func(db *gorm.DB) {
109	supportReturning := utils.Contains(config.DeleteClauses, "RETURNING")
110
111	return func(db *gorm.DB) {
112		if db.Error != nil {
113			return
114		}
115
116		if db.Statement.Schema != nil && !db.Statement.Unscoped {
117			for _, c := range db.Statement.Schema.DeleteClauses {
118				db.Statement.AddClause(c)
119			}
120		}
121
122		if db.Statement.SQL.String() == "" {
123			db.Statement.SQL.Grow(100)
124			db.Statement.AddClauseIfNotExists(clause.Delete{})
125
126			if db.Statement.Schema != nil {
127				_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields)
128				column, values := schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
129
130				if len(values) > 0 {
131					db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
132				}
133
134				if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
135					_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
136					column, values = schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)
137
138					if len(values) > 0 {
139						db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
140					}
141				}
142			}
143
144			db.Statement.AddClauseIfNotExists(clause.From{})
145			db.Statement.Build(db.Statement.BuildClauses...)
146		}
147
148		if _, ok := db.Statement.Clauses["WHERE"]; !db.AllowGlobalUpdate && !ok && db.Error == nil {
149			db.AddError(gorm.ErrMissingWhereClause)
150			return
151		}
152
153		if !db.DryRun && db.Error == nil {
154			if ok, mode := hasReturning(db, supportReturning); ok {
155				if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
156					gorm.Scan(rows, db, mode)
157					rows.Close()
158				}
159			} else {
160				result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
161				if db.AddError(err) == nil {
162					db.RowsAffected, _ = result.RowsAffected()
163				}
164			}
165		}
166	}
167}
168
169func AfterDelete(db *gorm.DB) {
170	if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterDelete {
171		callMethod(db, func(value interface{}, tx *gorm.DB) bool {
172			if i, ok := value.(AfterDeleteInterface); ok {
173				db.AddError(i.AfterDelete(tx))
174				return true
175			}
176			return false
177		})
178	}
179}
180