1 /*-------------------------------------------------------------------------
2  *
3  * qualify_type_stmt.c
4  *	  Functions specialized in fully qualifying all type statements. These
5  *	  functions are dispatched from qualify.c
6  *
7  *	  Fully qualifying type statements consists of adding the schema name
8  *	  to the subject of the types as well as any other branch of the
9  *	  parsetree.
10  *
 *	  The goal is that the deparser functions for these statements can
 *	  serialize each statement without any external (catalog) lookups.
13  *
14  * Copyright (c) Citus Data, Inc.
15  *
16  *-------------------------------------------------------------------------
17  */
18 
19 #include "postgres.h"
20 
21 #include "access/heapam.h"
22 #include "access/htup_details.h"
23 #include "catalog/namespace.h"
24 #include "catalog/objectaddress.h"
25 #include "catalog/pg_type.h"
26 #include "distributed/commands.h"
27 #include "distributed/deparser.h"
28 #include "distributed/version_compat.h"
29 #include "nodes/makefuncs.h"
30 #include "parser/parse_type.h"
31 #include "utils/syscache.h"
32 #include "utils/lsyscache.h"
33 
34 static char * GetTypeNamespaceNameByNameList(List *names);
35 static Oid TypeOidGetNamespaceOid(Oid typeOid);
36 
37 /*
38  * GetTypeNamespaceNameByNameList resolved the schema name of a type by its namelist.
39  */
40 static char *
GetTypeNamespaceNameByNameList(List * names)41 GetTypeNamespaceNameByNameList(List *names)
42 {
43 	TypeName *typeName = makeTypeNameFromNameList(names);
44 	Oid typeOid = LookupTypeNameOid(NULL, typeName, false);
45 	Oid namespaceOid = TypeOidGetNamespaceOid(typeOid);
46 	char *nspname = get_namespace_name_or_temp(namespaceOid);
47 	return nspname;
48 }
49 
50 
51 /*
52  * TypeOidGetNamespaceOid resolves the namespace oid for a type identified by its type oid
53  */
54 static Oid
TypeOidGetNamespaceOid(Oid typeOid)55 TypeOidGetNamespaceOid(Oid typeOid)
56 {
57 	HeapTuple typeTuple = SearchSysCache1(TYPEOID, typeOid);
58 
59 	if (!HeapTupleIsValid(typeTuple))
60 	{
61 		elog(ERROR, "citus cache lookup failed");
62 		return InvalidOid;
63 	}
64 	Form_pg_type typeData = (Form_pg_type) GETSTRUCT(typeTuple);
65 	Oid typnamespace = typeData->typnamespace;
66 
67 	ReleaseSysCache(typeTuple);
68 
69 	return typnamespace;
70 }
71 
72 
73 void
QualifyRenameTypeStmt(Node * node)74 QualifyRenameTypeStmt(Node *node)
75 {
76 	RenameStmt *stmt = castNode(RenameStmt, node);
77 	List *names = (List *) stmt->object;
78 
79 	Assert(stmt->renameType == OBJECT_TYPE);
80 
81 	if (list_length(names) == 1)
82 	{
83 		/* not qualified, lookup name and add namespace name to names */
84 		char *nspname = GetTypeNamespaceNameByNameList(names);
85 		names = list_make2(makeString(nspname), linitial(names));
86 
87 		stmt->object = (Node *) names;
88 	}
89 }
90 
91 
92 void
QualifyRenameTypeAttributeStmt(Node * node)93 QualifyRenameTypeAttributeStmt(Node *node)
94 {
95 	RenameStmt *stmt = castNode(RenameStmt, node);
96 	Assert(stmt->renameType == OBJECT_ATTRIBUTE);
97 	Assert(stmt->relationType == OBJECT_TYPE);
98 
99 	if (stmt->relation->schemaname == NULL)
100 	{
101 		List *names = list_make1(makeString(stmt->relation->relname));
102 		char *nspname = GetTypeNamespaceNameByNameList(names);
103 		stmt->relation->schemaname = nspname;
104 	}
105 }
106 
107 
108 void
QualifyAlterEnumStmt(Node * node)109 QualifyAlterEnumStmt(Node *node)
110 {
111 	AlterEnumStmt *stmt = castNode(AlterEnumStmt, node);
112 	List *names = stmt->typeName;
113 
114 	if (list_length(names) == 1)
115 	{
116 		/* not qualified, lookup name and add namespace name to names */
117 		char *nspname = GetTypeNamespaceNameByNameList(names);
118 		names = list_make2(makeString(nspname), linitial(names));
119 
120 		stmt->typeName = names;
121 	}
122 }
123 
124 
125 void
QualifyAlterTypeStmt(Node * node)126 QualifyAlterTypeStmt(Node *node)
127 {
128 	AlterTableStmt *stmt = castNode(AlterTableStmt, node);
129 	Assert(AlterTableStmtObjType_compat(stmt) == OBJECT_TYPE);
130 
131 	if (stmt->relation->schemaname == NULL)
132 	{
133 		List *names = MakeNameListFromRangeVar(stmt->relation);
134 		char *nspname = GetTypeNamespaceNameByNameList(names);
135 		stmt->relation->schemaname = nspname;
136 	}
137 }
138 
139 
140 void
QualifyCompositeTypeStmt(Node * node)141 QualifyCompositeTypeStmt(Node *node)
142 {
143 	CompositeTypeStmt *stmt = castNode(CompositeTypeStmt, node);
144 
145 	if (stmt->typevar->schemaname == NULL)
146 	{
147 		Oid creationSchema = RangeVarGetCreationNamespace(stmt->typevar);
148 		stmt->typevar->schemaname = get_namespace_name(creationSchema);
149 	}
150 }
151 
152 
153 void
QualifyCreateEnumStmt(Node * node)154 QualifyCreateEnumStmt(Node *node)
155 {
156 	CreateEnumStmt *stmt = castNode(CreateEnumStmt, node);
157 
158 	if (list_length(stmt->typeName) == 1)
159 	{
160 		char *objname = NULL;
161 		Oid creationSchema = QualifiedNameGetCreationNamespace(stmt->typeName, &objname);
162 		stmt->typeName = list_make2(makeString(get_namespace_name(creationSchema)),
163 									linitial(stmt->typeName));
164 	}
165 }
166 
167 
168 void
QualifyAlterTypeSchemaStmt(Node * node)169 QualifyAlterTypeSchemaStmt(Node *node)
170 {
171 	AlterObjectSchemaStmt *stmt = castNode(AlterObjectSchemaStmt, node);
172 	Assert(stmt->objectType == OBJECT_TYPE);
173 
174 	List *names = (List *) stmt->object;
175 	if (list_length(names) == 1)
176 	{
177 		/* not qualified with schema, lookup type and its schema s*/
178 		char *nspname = GetTypeNamespaceNameByNameList(names);
179 		names = list_make2(makeString(nspname), linitial(names));
180 		stmt->object = (Node *) names;
181 	}
182 }
183 
184 
185 void
QualifyAlterTypeOwnerStmt(Node * node)186 QualifyAlterTypeOwnerStmt(Node *node)
187 {
188 	AlterOwnerStmt *stmt = castNode(AlterOwnerStmt, node);
189 	Assert(stmt->objectType == OBJECT_TYPE);
190 
191 	List *names = (List *) stmt->object;
192 	if (list_length(names) == 1)
193 	{
194 		/* not qualified with schema, lookup type and its schema s*/
195 		char *nspname = GetTypeNamespaceNameByNameList(names);
196 		names = list_make2(makeString(nspname), linitial(names));
197 		stmt->object = (Node *) names;
198 	}
199 }
200