1-- xpattern.lua
2-- Preliminary regular expression-like support in Lua
3-- Uses Lua patterns as the core building block.
4--
5-- Implemented in pure Lua with code generation technique.
6-- It translates an expression into a snippet of Lua code
7-- having a series of `string.match` calls, which is then
8-- compiled (via `loadstring`).
9--
10-- Like lpeg, does not support backtracking.
11--
12-- WARNING: This is experimental code.  The design and implementation
13-- has not been thoroughly tested.
14--
15-- Version v20091021.
16-- (c) 2008-2009 David Manura. Licensed under the same terms as Lua (MIT license).
17-- Please post patches.
18
19M = {}
20
21local string = string
22local format = string.format
23local match  = string.match
24local assert = assert
25local error  = error
26local ipairs = ipairs
27local loadstring   = loadstring
28local setmetatable = setmetatable
29local type   = type
30local print  = print
31
32
-- Indents every line of string `s` and guarantees a trailing newline.
-- Whitespace string `ws` (defaults to '' if omitted) is inserted in front of
-- each non-empty line of `s`; empty lines and the line terminators themselves
-- are left alone.  If the result does not already end in a line terminator,
-- '\n' is appended (an empty `s` is returned unchanged).
local function add_whitespace(s, ws)
  local indent = ws or ''
  -- Keep only the first gsub result (the new string); the count is dropped.
  local indented = s:gsub('[^\r\n]+', indent .. '%1')
  if indented:match('[^\r\n]$') then
    return indented .. '\n'
  end
  return indented
end
44
-- Counts the number `count` of captures in Lua pattern string `pat`.
-- Every unescaped '(' outside a character set opens one capture (a position
-- capture '()' therefore counts once).  The following are skipped without
-- counting anything:
--   * escaped characters ('%x');
--   * balanced-match items ('%bxy'), including both delimiter characters;
--   * character sets ('[...]'); a ']' directly after '[' or '[^' is a literal
--     member of the set, as in Lua itself.
-- Raises error 'syntax' (reported at the caller) for a dangling '%' or a
-- character set that is never closed.
local function count_captures(pat)
  local count = 0
  local pos = 1
  while pos <= #pat do
    -- Fast path: skip a run of characters that need no special treatment.
    local pos2 = pat:match('^[^%(%%%[]+()', pos)
    if pos2 then
      pos = pos2
    elseif pat:match('^%(', pos) then
      count = count + 1
      pos = pos + 1
    elseif pat:match('^%%b..', pos) then
      -- '%bxy' is four characters long.  Skipping only three would rescan the
      -- closing delimiter and miscount it when it is '(' (e.g. '%b)(').
      pos = pos + 4
    elseif pat:match('^%%.', pos) then
      pos = pos + 2
    else
      -- Character set: '[', optional '^', optional literal ']' in first
      -- position, then plain members up to the first '%' escape or ']'.
      local set_pos = pat:match('^%[%^?%]?[^%]%%]*()', pos)
      if set_pos then
        pos = set_pos
        while 1 do
          -- Consume an escaped member plus the plain members following it.
          local esc_pos = pat:match('^%%.[^%]%%]*()', pos)
          if esc_pos then
            pos = esc_pos
          elseif pat:match('^%]', pos) then
            pos = pos + 1
            break
          else
            error('syntax', 2)
          end
        end
      else
        error('syntax', 2)
      end
    end
  end
  return count
end
-- Exported under an underscore-prefixed name so the otherwise file-local
-- capture counter is reachable through the module table `M`.
M._count_captures = count_captures
83
84
-- Appends a position capture '()' to Lua pattern string `pat`.
-- If `pat` ends with the end anchor '$', the capture is inserted just before
-- the anchor so the pattern stays anchored.  A trailing '$' preceded by an odd
-- number of '%' characters is an escaped literal dollar ('%$'), not an anchor;
-- in that case the capture is simply appended after it.
--  `pat` - Lua pattern string
-- Returns the new pattern string.
local function pat_append_pos(pat)
  -- `escapes` is the whole run of '%' characters directly before the last '$'.
  local head, escapes = pat:match('^(.-)(%%*)%$$')
  if head and #escapes % 2 == 0 then
    return head .. escapes .. '()$'
  end
  return pat .. '()'
end
91
-- Inserts a position capture '()' at the front of Lua pattern string `pat`.
-- A leading '^' anchor stays first: the capture is placed right after it.
local function pat_prepend_pos(pat)
  if pat:sub(1, 1) == '^' then
    return '^()' .. pat:sub(2)
  end
  return '()' .. pat
end
98
99
-- Ensures Lua pattern string `pat` starts with the '^' anchor.
-- A pattern that already begins with '^' is returned unchanged.
local function pat_prepend_carrot(pat)
  if pat:sub(1, 1) == '^' then
    return pat
  end
  return '^' .. pat
end
106
107
-- Builds the comma-separated list of capture variable names 'c<i>' for every
-- index `i` from `firstidx` to `lastidx` inclusive.
-- The result is '' when `firstidx` > `lastidx`.
-- Ex: code_vars(1,2) --> 'c1,c2'
local function code_vars(firstidx, lastidx)
  local names = {}
  for idx = firstidx, lastidx do
    names[#names + 1] = 'c' .. idx
  end
  return table.concat(names, ',')
end
118
119
-- Metatable for expression objects.
-- `__index` points back at the metatable itself, so any field stored on
-- `epat_mt` is also found when it is looked up on a pattern object.
local epat_mt = {}
epat_mt.__index = epat_mt
123
124
-- Wraps Lua pattern string `pat` in an extended pattern object `epat`.
-- `epat.call(srcidx0, destidx0, totncaptures0)` returns
--   `code`      - a single assignment statement that runs `match` on `s`
--                 starting at variable pos<srcidx0>; the pattern's captures
--                 go to variables c<totncaptures0+1>, ... and the position
--                 just past the match goes to pos<destidx0>
--   `ncaptures` - number of captures in `pat`.
-- `epat.anchored` is true iff `pat` begins with '^'.
local function pattern(pat)
  local function emit(srcidx0, destidx0, totncaptures0)
    local ncaptures = count_captures(pat)

    -- Assignment targets: the capture variables followed by the position.
    local targets = code_vars(totncaptures0 + 1, totncaptures0 + ncaptures)
    if ncaptures ~= 0 then
      targets = targets .. ','
    end
    targets = targets .. 'pos' .. destidx0

    -- Record where the match ends, and force it to start at the given position.
    local full_pat = pat_prepend_carrot(pat_append_pos(pat))
    local quoted = format('%q', full_pat)

    local code = targets .. ' = match(s, ' .. quoted .. ', pos' .. srcidx0 .. ')\n'
    return code, ncaptures
  end

  local epat = setmetatable({}, epat_mt)
  epat.call = emit
  epat.anchored = pat:sub(1,1) == '^'
  return epat
end
144
145
-- Generates code from pattern `anypat`.
--  `anypat`    - either Lua pattern string or extended pattern object; a
--                string is first anchored with '^' and wrapped by `pattern`
--  `srcidx0`   - index of variable holding position to start matching at
--  `destidx0`  - index of variable holding position to store subsequent
--                match position at.  stores nil if no match
--  `totncaptures0` - number of captures prior to this match
-- Returns
--  `code`      - Lua code string
--  `ncaptures` - number of captures in pattern.
-- Arguments beyond the four listed are ignored.
local function gen(anypat, srcidx0, destidx0, totncaptures0)
  local epat = anypat
  if type(epat) == 'string' then
    epat = pattern(pat_prepend_carrot(epat))
  end
  local code, ncaptures = epat(srcidx0, destidx0, totncaptures0)
  return code, ncaptures
end
163
164
-- Creates a new extended pattern object `epat` that matches the given pattern
-- objects (a list of size >= 0) one after another.
-- String arguments are converted to extended pattern objects, so passing a
-- single Lua pattern string simply converts it.
-- The generated code runs the first pattern from pos<srcidx0>; every later
-- pattern continues from pos<destidx0> and is guarded by
-- 'if pos<destidx0> then', so it is skipped once an earlier pattern failed.
-- An empty sequence just copies pos<srcidx0> to pos<destidx0>.
-- The result is marked `anchored` when its first pattern is.
local function seq(...) -- epats...
  -- Promote plain Lua pattern strings to extended pattern objects.
  local epats = {...}
  for i=1,#epats do
    local item = epats[i]
    if type(item) == 'string' then
      epats[i] = pattern(item)
    end
  end

  local function emit(srcidx0, destidx0, totncaptures0, ws)
    local n = #epats
    if n == 0 then
      return 'pos' .. destidx0 .. ' = pos' .. srcidx0 .. '\n', 0
    end
    if n == 1 then
      return epats[1](srcidx0, destidx0, totncaptures0, ws)
    end

    local ncaptures = 0
    local chunks = {}
    for i = 1, n do
      -- Only the first pattern reads the caller's start position.
      local fromidx = destidx0
      if i == 1 then
        fromidx = srcidx0
      end
      local sub_code, sub_ncaptures =
          gen(epats[i], fromidx, destidx0, totncaptures0 + ncaptures, true)
      ncaptures = ncaptures + sub_ncaptures
      if i == 1 then
        chunks[#chunks + 1] = add_whitespace(sub_code, '')
      else
        chunks[#chunks + 1] = 'if pos' .. destidx0 .. ' then\n'
        chunks[#chunks + 1] = add_whitespace(sub_code, '  ')
        chunks[#chunks + 1] = 'end\n'
      end
    end
    return table.concat(chunks), ncaptures
  end

  local epat = setmetatable({}, epat_mt)
  epat.call = emit
  if epats[1] and epats[1].anchored then
    epat.anchored = true
  end
  return epat
end
M.P = seq
210
211
-- Creates new extended pattern object `epat` that is the alternation of the
-- given list of pattern objects `epats...` (strings are converted to pattern
-- objects first).
-- The generated code tries every alternative from the same start position
-- pos<srcidx0>, writing into a scratch variable pos<destidx0+1>; each
-- alternative after the first runs only while the scratch variable is still
-- nil.  Finally the scratch variable is copied to pos<destidx0>.
-- With no alternatives the code just copies pos<srcidx0> to pos<destidx0>.
local function alt(...) -- epats...
  -- Promote plain Lua pattern strings to extended pattern objects.
  local epats = {...}
  for i=1,#epats do
    local item = epats[i]
    if type(item) == 'string' then
      epats[i] = pattern(item)
    end
  end

  local function emit(srcidx0, destidx0, totncaptures0)
    local n = #epats
    if n == 0 then
      return 'pos' .. destidx0 .. ' = pos' .. srcidx0 .. '\n', 0
    end
    if n == 1 then
      return epats[1](srcidx0, destidx0, totncaptures0)
    end

    local tmpidx = destidx0 + 1
    local ncaptures = 0
    local chunks = {}
    for i = 1, n do
      local sub_code, sub_ncaptures =
          gen(epats[i], srcidx0, tmpidx, totncaptures0 + ncaptures, true)
      ncaptures = ncaptures + sub_ncaptures
      if i == 1 then
        chunks[#chunks + 1] = 'local pos' .. tmpidx .. ' = pos' .. srcidx0 .. '\n'
        chunks[#chunks + 1] = add_whitespace(sub_code, '')
      else
        chunks[#chunks + 1] = 'if not pos' .. tmpidx .. ' then\n'
        chunks[#chunks + 1] = add_whitespace(sub_code, '  ')
        chunks[#chunks + 1] = 'end\n'
      end
    end
    chunks[#chunks + 1] = 'pos' .. destidx0 .. ' = pos' .. tmpidx .. '\n'
    return table.concat(chunks), ncaptures
  end

  local epat = setmetatable({}, epat_mt)
  epat.call = emit
  return epat
end
M.alt = alt
254
255
-- Creates new extended pattern object `epat` that is zero or more repetitions
-- of the given pattern object `pat` (if `one` evaluates to false) or one or
-- more repetitions (if `one` evaluates to true).
--  `pat` - Lua pattern string or extended pattern object
--  `one` - truthy to require at least one repetition
-- The generated code keeps a scratch position pos<destidx0+1> and re-runs
-- `pat` from it in a loop; after every successful repetition the scratch
-- position is copied to pos<destidx0>.  pos<destidx0> starts as pos<srcidx0>
-- (zero or more) or nil (one or more), so it ends up nil only when `one` is
-- set and not a single repetition matched.
-- The loop also stops after a repetition that consumed nothing: repeating an
-- empty match can never make progress and would otherwise loop forever
-- (e.g. when `pat` is 'a*').
local function star(pat, one)
  local epat = setmetatable({}, epat_mt)
  epat.call = function(srcidx0, destidx0, totncaptures0)
    local ncaptures = 0
    local destidx = destidx0 + 1
    local code = 'do\n' ..
                 '  local pos' .. destidx .. '=pos' .. srcidx0 .. '\n'
    local pat_code, pat_ncaptures =
        gen(pat, destidx, destidx, totncaptures0+ncaptures, true)
    ncaptures = ncaptures + pat_ncaptures
    code = code ..
      (one and ('  pos' .. destidx0 .. ' = nil\n')
           or  ('  pos' .. destidx0 .. ' = pos' .. srcidx0 .. '\n')) ..
      '  while 1 do\n' ..
      -- Remember where this repetition started to detect an empty match.
      '    local prev' .. destidx .. '=pos' .. destidx .. '\n' ..
           add_whitespace(pat_code, '    ') ..
      '    if pos' .. destidx .. ' then\n' ..
      '      pos' .. destidx0 .. '=pos' .. destidx .. '\n' ..
      '      if pos' .. destidx .. '==prev' .. destidx .. ' then break end\n' ..
      '    else break end\n' ..
      '  end\n' ..
      'end\n'
    return code, ncaptures
  end
  return epat
end
M.star = star
284
285
-- Creates new extended pattern object `epat` that is zero or one of the
-- given pattern object `epat0` (a Lua pattern string is accepted as well).
-- The generated code attempts `epat0` on a scratch position pos<destidx0+1>
-- initialised from pos<srcidx0>.  If the attempt succeeds, pos<destidx0>
-- receives the scratch position; otherwise it receives pos<srcidx0>, i.e. the
-- optional part matches nothing.
local function zero_or_one(epat0)
  local function emit(srcidx0, destidx0, totncaptures0)
    local ncaptures = 0
    local tmpidx = destidx0 + 1
    local tmp  = 'pos' .. tmpidx
    local src  = 'pos' .. srcidx0
    local dest = 'pos' .. destidx0

    local sub_code, sub_ncaptures =
        gen(epat0, tmpidx, tmpidx, totncaptures0 + ncaptures, true)
    ncaptures = ncaptures + sub_ncaptures

    local code = table.concat{
      'do\n',
      '  local ', tmp, '=', src, '\n',
      add_whitespace(sub_code),
      '  if ', tmp, ' then\n',
      '    ', dest, '=', tmp, '\n',
      '  else\n',
      '    ', dest, '=', src, '\n',
      '  end\n',
      'end\n',
    }
    return code, ncaptures
  end

  local epat = setmetatable({}, epat_mt)
  epat.call = emit
  return epat
end
M.zero_or_one = zero_or_one
311
312
-- Gets Lua core code string `code` corresponding to pattern object `epat`.
-- The returned code is a function body fragment that expects `s` (subject
-- string), `pos` (optional start position, default 1) and `match` to be in
-- scope.  On success it returns the captures c1, c2, ...; when the pattern
-- has no captures it returns the matched substring s:sub(pos1,pos2-1).  On
-- failure it falls through without returning a value.
-- If `epat.anchored` is set the pattern is tried once at the start position;
-- otherwise every start position is tried in turn.  The scan runs through
-- #s+1 so that a pattern able to match the empty string can still match at
-- the very end of `s` (or in an empty `s`), as `string.match` does.
local function basic_code_of(epat)
  -- Start position lives in pos1, end position in pos2, no prior captures.
  local pat_code, ncaptures = epat(1, 2, 0, true)
  local lvars = code_vars(1, ncaptures)

  if epat.anchored then
    local code =
      'local pos1=pos or 1\n' ..
      'local pos2\n' ..
      (lvars ~= '' and '  local ' .. lvars .. '\n' or '') ..
      add_whitespace(pat_code) .. '\n' ..
      'if pos2 then return ' .. (lvars ~= '' and lvars or 's:sub(pos1,pos2-1)') .. ' end\n'
    return code
  else
    local code =
        'for pos1=(pos or 1),#s+1 do\n' ..
        '  local pos2\n'
    if lvars == '' then
      code =
        code ..
           add_whitespace(pat_code, '  ') ..
        '  if pos2 then return s:sub(pos1,pos2-1) end\n'
    else
      code =
        code ..
        '  local ' .. lvars .. '\n' ..
           add_whitespace(pat_code, '  ') ..
        '  if pos2 then return ' .. lvars .. ' end\n'
    end
    code =
        code ..
        'end\n'
    return code
  end
end
M.basic_code_of = basic_code_of
349
350
-- Gets Lua complete code string `code` corresponding to pattern object `epat`.
-- The core code from `basic_code_of` is indented by two spaces and wrapped in
-- a chunk that takes `match` as its vararg and returns a function of
-- `(s, pos)`.
local function code_of(epat)
  local body = add_whitespace(basic_code_of(epat), '  ')
  local header = 'local match = ...\n' .. 'return function(s,pos)\n'
  return header .. body .. 'end\n'
end
M.code_of = code_of
361
362
-- Compiles pattern object `epat` to Lua function `f`.
-- The code string from `code_of` is loaded as a chunk and that chunk is called
-- with `match`; its return value `f` takes `(s, pos)`.
-- When `M.debug` is set, the generated code is printed first.
-- Raises an error (via `assert`) if the generated code fails to load.
local function compile(epat)
  local code = code_of(epat)
  if M.debug then print('DEBUG:\n' .. code) end
  -- `loadstring` exists in Lua 5.1/LuaJIT only; from Lua 5.2 on it is gone and
  -- `load` accepts a source string instead.
  local load_chunk = loadstring or load
  local f = assert(load_chunk(code))(match)
  return f
end
M.compile = compile
371
372
-- operator for pattern matching: calling a pattern object forwards all
-- arguments to the function stored in its `call` field and returns whatever
-- that function returns.
epat_mt.__call = function(epat, ...)
  return epat.call(...)
end
377
378
-- operator for pattern alternation: `a + b` builds `alt(a, b)`.
epat_mt.__add = function(a_epat, b_epat)
  return alt(a_epat, b_epat)
end
383
384
-- operator for pattern concatenation: `a * b` builds `seq(a, b)`.
epat_mt.__mul = function(a_epat, b_epat)
  return seq(a_epat, b_epat)
end
389
390
-- operator for pattern repetition:
--   epat^0   zero or more repetitions
--   epat^1   one or more repetitions
--   epat^-1  zero or one occurrence
-- Any other exponent raises an error (not implemented).
function epat_mt.__pow(epat, n)
  if n == -1 then
    return zero_or_one(epat)
  end
  if n == 0 then
    return star(epat)
  end
  if n == 1 then
    return star(epat, true)
  end
  error 'FIX - unimplemented'
end
403
404
-- IMPROVE design?
-- Store these functions on the metatable as well.  Since `epat_mt.__index`
-- is `epat_mt`, they can then be called as methods on a pattern object,
-- e.g. `epat:compile()`.
epat_mt.compile       = compile
epat_mt.basic_code_of = basic_code_of
epat_mt.code_of       = code_of
409
410
411