1# frozen_string_literal: true
2
module Database
  # RSpec helpers for asserting on PostgreSQL schema state (tables, columns,
  # indexes, constraints) by querying the pg_catalog system tables directly.
  #
  # The expect_* helpers call RSpec's `expect` / `eq` / `be_nil`, so this module
  # is intended to be mixed into an example group.
  #
  # NOTE: table, column, index, constraint and schema names are interpolated
  # directly into SQL string literals. Only pass trusted, test-authored
  # identifiers.
  module TableSchemaHelpers
    # @return the ActiveRecord connection used for every catalog query below
    def connection
      ActiveRecord::Base.connection
    end

    # Asserts that the given block swaps tables around by OID:
    # - afterwards `original_table` resolves to the OID `replacement_table` had before,
    # - `archived_table` resolves to the OID `original_table` had before,
    # - `replacement_table` no longer resolves to any pg_class row.
    #
    # @param original_table [String] name of the table being replaced
    # @param replacement_table [String] name of the table taking its place
    # @param archived_table [String] name the original table is expected to be renamed to
    # @yield the operation that performs the swap
    def expect_table_to_be_replaced(original_table:, replacement_table:, archived_table:)
      original_oid = table_oid(original_table)
      replacement_oid = table_oid(replacement_table)

      yield

      expect(table_oid(original_table)).to eq(replacement_oid)
      expect(table_oid(archived_table)).to eq(original_oid)
      expect(table_oid(replacement_table)).to be_nil
    end

    # Asserts that `table_name` exists and that its columns, in order, match
    # `expected_column_attributes`.
    #
    # @param expected_column_attributes [Array<Hash>] one hash per expected column,
    #   in column order; each key names a reader on the adapter's column object
    #   (e.g. :name) and each value is the expected result of that reader
    # @param table_name [String]
    def expect_table_columns_to_match(expected_column_attributes, table_name)
      expect(connection.table_exists?(table_name)).to eq(true)

      actual_columns = connection.columns(table_name)
      expect(actual_columns.size).to eq(expected_column_attributes.size)

      expected_column_attributes.zip(actual_columns) do |attributes, actual_column|
        attributes.each do |name, value|
          actual_value = actual_column.public_send(name)
          message = "expected #{actual_column.name}.#{name} to be #{value}, but got #{actual_value}"

          expect(actual_value).to eq(value), message
        end
      end
    end

    # Asserts that an index named `name` exists in `schema` (default: the
    # session's current_schema).
    def expect_index_to_exist(name, schema: nil)
      expect(index_exists_by_name(name, schema: schema)).to eq(true)
    end

    # Asserts that no index named `name` exists in `schema` (default: the
    # session's current_schema).
    def expect_index_not_to_exist(name, schema: nil)
      expect(index_exists_by_name(name, schema: schema)).to be_nil
    end

    # Asserts that a foreign-key constraint named `name` exists on `table_name`
    # (optionally schema-qualified via `schema`).
    def expect_foreign_key_to_exist(table_name, name, schema: nil)
      expect(foreign_key_exists_by_name(table_name, name, schema: schema)).to eq(true)
    end

    # Asserts that no foreign-key constraint named `name` exists on `table_name`.
    def expect_foreign_key_not_to_exist(table_name, name, schema: nil)
      expect(foreign_key_exists_by_name(table_name, name, schema: schema)).to be_nil
    end

    # Asserts that check constraint `name` on `table_name` has exactly the
    # definition `definition`, compared against pg_get_constraintdef output
    # wrapped as "CHECK ((definition))".
    def expect_check_constraint(table_name, name, definition, schema: nil)
      expect(check_constraint_definition(table_name, name, schema: schema)).to eq("CHECK ((#{definition}))")
    end

    # Asserts that each table in `tables` has a primary-key constraint named
    # "<table>_pkey".
    #
    # @param tables [Enumerable<String>]
    def expect_primary_keys_after_tables(tables, schema: nil)
      tables.each do |table|
        primary_key = primary_key_constraint_name(table, schema: schema)

        expect(primary_key).to eq("#{table}_pkey")
      end
    end

    # @return the pg_class OID for a relation named `name`, or nil if none exists.
    # NOTE(review): matches on relname only, across all schemas and relation
    # kinds; if the name exists in several schemas, which row is returned is
    # unspecified.
    def table_oid(name)
      connection.select_value(<<~SQL)
        SELECT oid
        FROM pg_catalog.pg_class
        WHERE relname = '#{name}'
      SQL
    end

    # @return [String, nil] 'normal' for an ordinary table (relkind 'r'),
    #   'partitioned' for a partitioned table (relkind 'p'), 'other' for any
    #   other relation kind, or nil when no relation has that name.
    # NOTE(review): like table_oid, this matches on relname only, across all schemas.
    def table_type(name)
      connection.select_value(<<~SQL)
        SELECT
          CASE class.relkind
          WHEN 'r' THEN 'normal'
          WHEN 'p' THEN 'partitioned'
          ELSE 'other'
          END as table_type
        FROM pg_catalog.pg_class class
        WHERE class.relname = '#{name}'
      SQL
    end

    # @return [String, nil] the name of the sequence that depends on
    #   `table_name`.`column_name` (e.g. a serial/identity/OWNED BY sequence),
    #   or nil if there is none.
    def sequence_owned_by(table_name, column_name)
      # Restrict to pg_class-to-pg_class dependencies whose dependent object is a
      # sequence; otherwise e.g. an index on the column (which also has a
      # pg_depend row referencing it) could be returned instead.
      connection.select_value(<<~SQL)
        SELECT
          sequence.relname as name
        FROM pg_catalog.pg_class as sequence
        INNER JOIN pg_catalog.pg_depend depend
          ON depend.objid = sequence.oid
          AND depend.classid = 'pg_catalog.pg_class'::regclass
          AND depend.refclassid = 'pg_catalog.pg_class'::regclass
        INNER JOIN pg_catalog.pg_class class
          ON class.oid = depend.refobjid
        INNER JOIN pg_catalog.pg_attribute attribute
          ON attribute.attnum = depend.refobjsubid
          AND attribute.attrelid = depend.refobjid
        WHERE class.relname = '#{table_name}'
          AND attribute.attname = '#{column_name}'
          AND sequence.relkind = 'S'
      SQL
    end

    # @return [String, nil] the column's default expression as rendered by
    #   pg_get_expr, or nil if the column has no default.
    def default_expression_for(table_name, column_name)
      connection.select_value(<<~SQL)
        SELECT
          pg_get_expr(attrdef.adbin, attrdef.adrelid) AS default_value
        FROM pg_catalog.pg_attribute attribute
        INNER JOIN pg_catalog.pg_attrdef attrdef
          ON attribute.attrelid = attrdef.adrelid
          AND attribute.attnum = attrdef.adnum
        WHERE attribute.attrelid = '#{table_name}'::regclass
          AND attribute.attname = '#{column_name}'
      SQL
    end

    # @return [String, nil] the name of the primary-key constraint on
    #   `table_name` (schema-qualified when `schema` is given), or nil if none.
    def primary_key_constraint_name(table_name, schema: nil)
      table_name = schema ? "#{schema}.#{table_name}" : table_name

      connection.select_value(<<~SQL)
        SELECT
          conname AS constraint_name
        FROM pg_catalog.pg_constraint
        WHERE pg_constraint.conrelid = '#{table_name}'::regclass
          AND pg_constraint.contype = 'p'
      SQL
    end

    # @return [true, nil] true when an index named `index` exists in `schema`
    #   (or in current_schema when `schema` is nil), nil otherwise.
    def index_exists_by_name(index, schema: nil)
      # Either a quoted schema literal or the SQL function current_schema.
      schema = schema ? "'#{schema}'" : 'current_schema'

      connection.select_value(<<~SQL)
        SELECT true
        FROM pg_catalog.pg_index i
        INNER JOIN pg_catalog.pg_class c
          ON c.oid = i.indexrelid
        INNER JOIN pg_catalog.pg_namespace n
          ON c.relnamespace = n.oid
        WHERE c.relname = '#{index}'
          AND n.nspname = #{schema}
      SQL
    end

    # @return [true, nil] true when a foreign-key constraint named
    #   `foreign_key_name` exists on `table_name`, nil otherwise.
    def foreign_key_exists_by_name(table_name, foreign_key_name, schema: nil)
      table_name = schema ? "#{schema}.#{table_name}" : table_name

      connection.select_value(<<~SQL)
        SELECT true
        FROM pg_catalog.pg_constraint
        WHERE pg_constraint.conrelid = '#{table_name}'::regclass
          AND pg_constraint.contype = 'f'
          AND pg_constraint.conname = '#{foreign_key_name}'
      SQL
    end

    # @return [String, nil] pg_get_constraintdef output (e.g. "CHECK ((...))")
    #   for check constraint `constraint_name` on `table_name`, or nil if absent.
    def check_constraint_definition(table_name, constraint_name, schema: nil)
      table_name = schema ? "#{schema}.#{table_name}" : table_name

      connection.select_value(<<~SQL)
        SELECT
          pg_get_constraintdef(oid) AS constraint_definition
        FROM pg_catalog.pg_constraint
        WHERE pg_constraint.conrelid = '#{table_name}'::regclass
          AND pg_constraint.contype = 'c'
          AND pg_constraint.conname = '#{constraint_name}'
      SQL
    end
  end
end
170