1# -*- coding: utf-8 -*-
2from __future__ import (nested_scopes, generators, division, absolute_import,
3                        with_statement, print_function, unicode_literals)
4import os
5import sys
6import multiprocessing as mltp
7import subprocess as sub
8import shutil as sht
9
10from grass.script.setup import write_gisrc
11
12from grass.pygrass.gis import Mapset, Location
13from grass.pygrass.gis.region import Region
14from grass.pygrass.modules import Module
15from grass.pygrass.utils import get_mapset_raster, findmaps
16
17from grass.pygrass.modules.grid.split import split_region_tiles
18from grass.pygrass.modules.grid.patch import rpatch_map
19
20
def select(parms, ptype):
    """Select only a certain type of parameters.

    Parameters without a value are skipped, so the iterator never yields
    ``None`` and never tries to iterate over an unset multiple parameter.

    :param parms: a DictType parameter with inputs or outputs of a Module class
    :type parms: DictType parameters
    :param ptype: String define the type of parameter that we want to select,
                  valid ptype are: 'raster', 'vector', 'group'
    :type ptype: str
    :returns: An iterator with the value of the parameter; the single values
              of a multiple parameter are yielded one by one.

    >>> slp = Module('r.slope.aspect',
    ...              elevation='ele', slope='slp', aspect='asp',
    ...              run_=False)
    >>> for rast in select(slp.outputs, 'raster'):
    ...     print(rast)
    ...
    slp
    asp
    """
    for k in parms:
        par = parms[k]
        # The parentheses matter: ``and`` binds tighter than ``or``, so
        # without them the ``par.value`` guard only protected the
        # ``typedesc`` branch.
        if (par.type == ptype or par.typedesc == ptype) and par.value:
            if par.multiple:
                for val in par.value:
                    yield val
            else:
                yield par.value
48
49
def copy_special_mapset_files(path_src, path_dst):
    """Copy the special GRASS files of a mapset into another mapset.

    Only the entries of *path_src* whose name is upper case (e.g. ``WIND``,
    ``VAR``, ``DEFAULT_WIND``) are copied into the directory *path_dst*;
    map data stored in lower-case directories is not touched.

    :param path_src: the path to the original mapset
    :type path_src: str
    :param path_dst: the path to the new mapset
    :type path_dst: str
    """
    special_names = [entry for entry in os.listdir(path_src)
                     if entry.isupper()]
    for entry in special_names:
        source_file = os.path.join(path_src, entry)
        sht.copy(source_file, path_dst)
61
62
def copy_mapset(mapset, path):
    """Copy mapset to another place without copying raster and vector data.

    *path* is the directory of the new location: a ``PERMANENT`` mapset and
    a mapset with the same name as *mapset* are created inside it (if they
    do not exist yet) and only their special (upper-case) files are copied.

    :param mapset: a Mapset instance to copy
    :type mapset: Mapset object
    :param path: path of the location directory where the new mapset must be
                 copied; a trailing path separator is ignored
    :type path: str
    :returns: the instance of the new Mapset.


    >>> from grass.script.core import gisenv
    >>> mname = gisenv()['MAPSET']
    >>> mset = Mapset()
    >>> mset.name == mname
    True
    >>> import tempfile as tmp
    >>> import os
    >>> path = os.path.join(tmp.gettempdir(), 'my_loc', 'my_mset')
    >>> copy_mapset(mset, path)                           # doctest: +ELLIPSIS
    Mapset(...)
    >>> sorted(os.listdir(path))                          # doctest: +ELLIPSIS
    [...'PERMANENT'...]
    >>> sorted(os.listdir(os.path.join(path, 'PERMANENT')))
    ['DEFAULT_WIND', 'PROJ_INFO', 'PROJ_UNITS', 'VAR', 'WIND']
    >>> sorted(os.listdir(os.path.join(path, mname)))   # doctest: +ELLIPSIS
    [...'SEARCH_PATH',...'WIND']
    >>> import shutil
    >>> shutil.rmtree(path)

    """
    # Normalize first: with a trailing separator os.path.split below would
    # return an empty location name.
    path = os.path.normpath(path)
    per_old = os.path.join(mapset.gisdbase, mapset.location, 'PERMANENT')
    per_new = os.path.join(path, 'PERMANENT')
    map_old = mapset.path()
    map_new = os.path.join(path, mapset.name)
    if not os.path.isdir(per_new):
        os.makedirs(per_new)
    if not os.path.isdir(map_new):
        os.mkdir(map_new)
    copy_special_mapset_files(per_old, per_new)
    copy_special_mapset_files(map_old, map_new)
    gisdbase, location = os.path.split(path)
    return Mapset(mapset.name, location, gisdbase)
105
106
def read_gisrc(gisrc):
    """Read a GISRC file and return a tuple with the mapset, location
    and gisdbase.

    Each line is split at its first ``:`` into a key and a value (so values
    such as Windows paths may contain further colons); lines without a
    ``:`` (e.g. blank lines) are ignored.

    :param gisrc: the path to GISRC file
    :type gisrc: str
    :returns: a tuple with the mapset, location and gisdbase
    :raises KeyError: if one of MAPSET, LOCATION_NAME or GISDBASE is missing

    >>> import os
    >>> from grass.script.core import gisenv
    >>> genv = gisenv()
    >>> (read_gisrc(os.environ['GISRC']) == (genv['MAPSET'],
    ...                                      genv['LOCATION_NAME'],
    ...                                      genv['GISDBASE']))
    True
    """
    gis = {}
    with open(gisrc, 'r') as gfile:
        for row in gfile:
            if ':' not in row:
                # blank or malformed line: nothing to unpack
                continue
            key, value = row.split(':', 1)
            gis[key.strip()] = value.strip()
    return gis['MAPSET'], gis['LOCATION_NAME'], gis['GISDBASE']
127
128
def get_mapset(gisrc_src, gisrc_dst):
    """Return the source and destination mapsets described by two GISRC files.

    If the destination mapset directory does not exist it is created and
    the special files of the source mapset are copied into it. The source
    mapset, together with the mapsets visible from it, is then added to the
    visible mapsets of the destination.

    :param gisrc_src: path to the GISRC source
    :type gisrc_src: str
    :param gisrc_dst: path to the GISRC destination
    :type gisrc_dst: str
    :returns: a tuple with Mapset(src), Mapset(dst)

    """
    src_name, src_loc, src_db = read_gisrc(gisrc_src)
    dst_name, dst_loc, dst_db = read_gisrc(gisrc_dst)
    src_path = os.path.join(src_db, src_loc, src_name)
    dst_path = os.path.join(dst_db, dst_loc, dst_name)
    if not os.path.isdir(dst_path):
        os.makedirs(dst_path)
        copy_special_mapset_files(src_path, dst_path)
    src = Mapset(src_name, src_loc, src_db)
    dst = Mapset(dst_name, dst_loc, dst_db)
    to_show = list(src.visible)
    if src.name not in to_show:
        to_show.append(src.name)
    dst.visible.extend(to_show)
    return src, dst
153
154
def copy_groups(groups, gisrc_src, gisrc_dst, region=None):
    """Copy groups from one mapset to another, cropping the rasters to the region.

    For each group the member rasters are listed in the source mapset with
    ``i.group -lg``; members not already present in the destination
    location are copied with copy_rasters, then the group is created in the
    destination mapset with ``i.group``.

    :param groups: a list of strings with the group that must be copied
                   from a master to another.
    :type groups: list of strings
    :param gisrc_src: path of the GISRC file from where we want to copy the groups
    :type gisrc_src: str
    :param gisrc_dst: path of the GISRC file where the groups will be created
    :type gisrc_dst: str
    :param region: a region like object or a dictionary with the region
                   parameters that will be used to crop the rasters of the
                   groups
    :type region: Region object or dictionary
    :returns: None

    """
    env = os.environ.copy()
    # instantiate the modules once and reuse them for every group
    get_grp = Module('i.group', flags='lg', stdout_=sub.PIPE, run_=False)
    set_grp = Module('i.group')
    get_grp.run_ = True

    def strip_mapset(mapname):
        # drop the "@mapset" part of a fully qualified map name, if any
        return mapname.split('@')[0]

    src = read_gisrc(gisrc_src)
    dst = read_gisrc(gisrc_dst)
    # with a different gisdbase the members are always referenced without
    # their mapset
    strip_all = src[2] != dst[2]
    dst_rasters = [found[0]
                   for found in findmaps('raster', location=dst[1],
                                         gisdbase=dst[2])]
    for grp in groups:
        # list the group members from the source mapset
        env['GISRC'] = gisrc_src
        get_grp(group=grp, env_=env)
        members = get_grp.outputs.stdout.split()
        # everything else happens in the destination mapset
        env['GISRC'] = gisrc_dst
        missing = [m for m in members if strip_mapset(m) not in dst_rasters]
        if missing:
            copy_rasters(missing, gisrc_src, gisrc_dst, region=region)
        if missing or strip_all:
            new_members = [strip_mapset(m) for m in members]
        else:
            new_members = members
        set_grp(group=grp, input=new_members, env_=env)
197
198
def set_region(region, gisrc_src, gisrc_dst, env):
    """Set a region into two different mapsets.

    ``g.region`` is run (and waited for) first with *gisrc_src* and then
    with *gisrc_dst*, so the region is in place when the function returns.

    :param region: a region like object or a dictionary with the region
                   parameters that will be used to crop the rasters of the
                   groups; it must provide north, south, east, west, nsres
                   and ewres
    :type region: Region object or dictionary
    :param gisrc_src: path of the GISRC file from where we want to copy the groups
    :type gisrc_src: str
    :param gisrc_dst: path of the GISRC file where the groups will be created
    :type gisrc_dst: str
    :param env: environment used to run ``g.region``; its ``GISRC`` entry is
                modified in place and left pointing to *gisrc_dst*
    :type env: dict
    :returns: None
    """
    reg_str = "g.region n=%(north)r s=%(south)r " \
              "e=%(east)r w=%(west)r " \
              "nsres=%(nsres)r ewres=%(ewres)r"
    reg_cmd = reg_str % dict(region.items())
    for gisrc in (gisrc_src, gisrc_dst):
        env['GISRC'] = gisrc
        # Wait for g.region: callers run modules that rely on this region
        # right after this function returns.
        sub.Popen(reg_cmd, shell=True, env=env).wait()
222
223
def copy_rasters(rasters, gisrc_src, gisrc_dst, region=None):
    """Copy rasters from one mapset to another, crop the raster to the region.

    For each raster a temporary copy is created with ``r.mapcalc`` in the
    source mapset, packed with ``r.pack`` into a file inside the destination
    mapset directory and unpacked there with ``r.unpack``, using the raster
    name without its ``@mapset`` part. The temporary copy and the pack file
    are removed even when one of the modules fails.

    :param rasters: a list of strings with the raster map that must be copied
                    from a master to another.
    :type rasters: list
    :param gisrc_src: path of the GISRC file from where we want to copy the groups
    :type gisrc_src: str
    :param gisrc_dst: path of the GISRC file where the groups will be created
    :type gisrc_dst: str
    :param region: a region like object or a dictionary with the region
                   parameters that will be set (with set_region) in both
                   mapsets before copying
    :type region: Region object or dictionary
    :returns: None
    """
    env = os.environ.copy()
    if region:
        set_region(region, gisrc_src, gisrc_dst, env)

    # read_gisrc returns (mapset, location, gisdbase): reverse it to build
    # the gisdbase/location/mapset path
    path_dst = os.path.join(*read_gisrc(gisrc_dst)[::-1])
    nam = "copy%d__%s" % (id(gisrc_dst), '%s')

    # instantiate modules
    mpclc = Module('r.mapcalc')
    rpck = Module('r.pack')
    rupck = Module('r.unpack')
    remove = Module('g.remove')

    for rast in rasters:
        rast_clean = rast.split('@')[0]
        name = nam % rast_clean
        file_dst = "%s.pack" % os.path.join(path_dst, name)
        # change gisdbase to src
        env['GISRC'] = gisrc_src
        mpclc(expression="%s=%s" % (name, rast), overwrite=True, env_=env)
        try:
            rpck(input=name, output=file_dst, overwrite=True, env_=env)
        finally:
            # never leave the temporary copy in the source mapset
            remove(flags='f', type='raster', name=name, env_=env)
        # change gisdbase to dst
        env['GISRC'] = gisrc_dst
        try:
            rupck(input=file_dst, output=rast_clean, overwrite=True, env_=env)
        finally:
            if os.path.exists(file_dst):
                os.remove(file_dst)
266
267
def copy_vectors(vectors, gisrc_src, gisrc_dst):
    """Copy vectors from one mapset to another.

    Each vector is packed with ``v.pack`` from the source mapset into a
    temporary file inside the destination mapset directory, unpacked there
    with ``v.unpack`` (using the name without its ``@mapset`` part) and the
    temporary file is removed. The source maps are never modified and,
    unlike copy_rasters, no cropping to a region is done.

    :param vectors: a list of strings with the vector map that must be copied
                    from a master to another.
    :type vectors: list
    :param gisrc_src: path of the GISRC file from where we want to copy the vectors
    :type gisrc_src: str
    :param gisrc_dst: path of the GISRC file where the vectors will be created
    :type gisrc_dst: str
    :returns: None
    """
    env = os.environ.copy()
    # read_gisrc returns (mapset, location, gisdbase): reverse it to build
    # the gisdbase/location/mapset path (as copy_rasters does)
    path_dst = os.path.join(*read_gisrc(gisrc_dst)[::-1])
    nam = "copy%d__%s" % (id(gisrc_dst), '%s')

    # instantiate modules
    vpck = Module('v.pack')
    vupck = Module('v.unpack')

    for vect in vectors:
        vect_clean = vect.split('@')[0]
        file_dst = "%s.pack" % os.path.join(path_dst, nam % vect_clean)
        # change gisdbase to src
        env['GISRC'] = gisrc_src
        # Pack the source map itself: no temporary map is created, so nothing
        # has to be removed from the source mapset afterwards.
        vpck(input=vect, output=file_dst, overwrite=True, env_=env)
        # change gisdbase to dst
        env['GISRC'] = gisrc_dst
        try:
            vupck(input=file_dst, output=vect_clean, overwrite=True, env_=env)
        finally:
            if os.path.exists(file_dst):
                os.remove(file_dst)
300
301
def get_cmd(cmdd):
    """Transform a cmd dictionary to a list of parameters. It is useful to
    pickle a Module class and convert into a string that can be used with
    `Popen(get_cmd(cmdd), shell=True)`.

    The list contains, in this order: the module name, the single-valued
    inputs, the multi-valued inputs, the single-valued outputs, the
    multi-valued outputs, the short flags (``-x``) and the long flags
    (``--x``, using only their first letter). Multi-valued options are
    comma-joined; strings are used as they are and other values (numbers)
    are written with ``repr``.

    :param cmdd: a module dictionary with all the parameters; ``inputs`` and
                 ``outputs`` are iterables of ``(key, value)`` pairs
    :type cmdd: dict
    :returns: the command as a list of strings

    >>> slp = Module('r.slope.aspect',
    ...              elevation='ele', slope='slp', aspect='asp',
    ...              overwrite=True, run_=False)
    >>> get_cmd(slp.get_dict())  # doctest: +ELLIPSIS
    ['r.slope.aspect', 'elevation=ele', 'format=degrees', ..., '--o']
    """
    def fmt(val):
        # numbers keep their repr; strings must not be quoted, otherwise the
        # quotes end up in the map names when no shell strips them
        return repr(val) if isinstance(val, (int, float)) else '%s' % val

    cmd = [cmdd['name'], ]
    for section in ('inputs', 'outputs'):
        # materialize once: the pairs are scanned twice below
        pairs = list(cmdd[section])
        cmd.extend("%s=%s" % (k, v) for k, v in pairs
                   if not isinstance(v, list))
        cmd.extend("%s=%s" % (k, ','.join(fmt(v) for v in vals))
                   for k, vals in pairs if isinstance(vals, list))
    cmd.extend("-%s" % flg for flg in cmdd['flags'] if len(flg) == 1)
    cmd.extend("--%s" % flg[0] for flg in cmdd['flags'] if len(flg) > 1)
    return cmd
330
331
def cmd_exe(args):
    """Create a mapset, and execute a cmd inside.

    :param args: is a tuple that contains several information see below
    :type args: tuple
    :returns: None

    The tuple has to contain:

    - bbox (dict): a dict with the region parameters (n, s, e, w, etc.)
      that we want to set before to apply the command.
    - mapnames (dict): a dictionary whose entries replace the matching
      entries of ``cmd['inputs']`` when the domain has been split in several
      tiles; its values are the tile maps.
    - gisrc_src (str): path of the GISRC file from where we want to copy the
      groups.
    - gisrc_dst (str): path of the GISRC file where the groups will be
      created; this file is removed when the function returns, also on error.
    - cmd (dict): a dictionary with all the parameter of a GRASS module (see
      get_cmd); it is copied, not modified.
    - groups (list): a list of strings with the groups that we want to copy in
      the mapset.

    """
    bbox, mapnames, gisrc_src, gisrc_dst, cmd, groups = args
    try:
        # creates the destination mapset if needed
        get_mapset(gisrc_src, gisrc_dst)
        env = os.environ.copy()
        env['GISRC'] = gisrc_dst
        shell = sys.platform == 'win32'
        # Shallow copy: the same cmd dict may be shared by several works
        # (e.g. when they are executed sequentially in the same process).
        cmd = dict(cmd)
        if mapnames:
            inputs = dict(cmd['inputs'])
            # replace the inputs with the tiles
            for key in mapnames:
                inputs[key] = mapnames[key]
            cmd['inputs'] = list(inputs.items())
            # set the region to the tile map itself, not to the dict key
            sub.Popen(['g.region', 'raster=%s' % mapnames[key]],
                      shell=shell, env=env).wait()
        else:
            # set the computational region
            lcmd = ['g.region', ]
            lcmd.extend("%s=%s" % (k, v) for k, v in bbox.items())
            sub.Popen(lcmd, shell=shell, env=env).wait()
        if groups:
            copy_groups(groups, gisrc_src, gisrc_dst)
        # run the grass command
        sub.Popen(get_cmd(cmd), shell=shell, env=env).wait()
    finally:
        # remove temp GISRC
        os.remove(gisrc_dst)
377
378
class GridModule(object):
    # TODO maybe also i.* could be supported easily
    """Run GRASS raster commands in a multiprocessing mode.

    The region is split in tiles, the command is executed for each tile in a
    dedicated mapset (in parallel with a process pool, or sequentially when
    ``debug`` is True) and the raster outputs are finally patched together.

    :param cmd: raster GRASS command, only command staring with r.* are valid.
    :type cmd: str
    :param width: width of the tile, in pixel
    :type width: int
    :param height: height of the tile, in pixel.
    :type height: int
    :param overlap: overlap between tiles, in pixel.
    :type overlap: int
    :param processes: number of worker processes, default value is equal to
                      the number of processors available.
    :type processes: int
    :param split: if True use r.tile to split all the inputs.
    :type split: bool
    :param debug: if True run the tiles sequentially in the current process
                  instead of using a process pool
    :type debug: bool
    :param region: the region to split in tiles, default is ``Region()``
    :type region: Region object
    :param move: path of a location directory where the current mapset and
                 the inputs are copied before running the command
    :type move: str
    :param log: if True an empty file named after each raster output is
                created in a temporary directory
    :type log: bool
    :param start_row: offset added to the row index of the tile mapsets
    :type start_row: int
    :param start_col: offset added to the column index of the tile mapsets
    :type start_col: int
    :param out_prefix: prefix added to the names of the patched outputs
    :type out_prefix: str
    :param mapset_prefix: if specified created mapsets start with this prefix
    :type mapset_prefix: str
    :param run_: ignored, the command is never executed at instantiation;
                 call :meth:`run`
    :type run_: bool
    :param args: give all the parameters to the command
    :param kargs: give all the parameters to the command

    >>> grd = GridModule('r.slope.aspect',
    ...                  width=500, height=500, overlap=2,
    ...                  processes=None, split=False,
    ...                  elevation='elevation',
    ...                  slope='slope', aspect='aspect', overwrite=True)
    >>> grd.run()
    """
    def __init__(self, cmd, width=None, height=None, overlap=0, processes=None,
                 split=False, debug=False, region=None, move=None, log=False,
                 start_row=0, start_col=0, out_prefix='', mapset_prefix=None,
                 *args, **kargs):
        # the wrapped module is only configured here, it runs in cmd_exe
        kargs['run_'] = False
        self.mset = Mapset()
        self.module = Module(cmd, *args, **kargs)
        self.width = width
        self.height = height
        self.overlap = overlap
        self.processes = processes
        self.region = region if region else Region()
        self.start_row = start_row
        self.start_col = start_col
        self.out_prefix = out_prefix
        self.log = log
        self.move = move
        self.gisrc_src = os.environ['GISRC']
        self.n_mset, self.gisrc_dst = None, None
        if self.move:
            # copy the mapset and all the inputs into the new location
            self.n_mset = copy_mapset(self.mset, self.move)
            self.gisrc_dst = write_gisrc(self.n_mset.gisdbase,
                                         self.n_mset.location,
                                         self.n_mset.name)
            rasters = list(select(self.module.inputs, 'raster'))
            if rasters:
                copy_rasters(rasters, self.gisrc_src, self.gisrc_dst,
                             region=self.region)
            vectors = list(select(self.module.inputs, 'vector'))
            if vectors:
                copy_vectors(vectors, self.gisrc_src, self.gisrc_dst)
            groups = list(select(self.module.inputs, 'group'))
            if groups:
                copy_groups(groups, self.gisrc_src, self.gisrc_dst,
                            region=self.region)
        self.bboxes = split_region_tiles(region=region,
                                         width=width, height=height,
                                         overlap=overlap)
        if mapset_prefix:
            self.msetstr = mapset_prefix + "_%03d_%03d"
        else:
            self.msetstr = cmd.replace('.', '') + "_%03d_%03d"
        self.inlist = None
        if split:
            self.split()
        self.debug = debug

    def __del__(self):
        # __init__ may have failed before setting gisrc_dst, and run() may
        # already have removed the file
        gisrc = getattr(self, 'gisrc_dst', None)
        if gisrc and os.path.isfile(gisrc):
            # remove GISRC file
            os.remove(gisrc)

    def clean_location(self, location=None):
        """Remove all the tile mapsets created by this object.

        The tile mapsets are matched by name: the mapset prefix (or the
        command name without dots) followed by ``_<row>_<col>``.

        :param location: a Location instance where we are running the
                         analysis; by default the current location (the
                         moved one, if ``move`` was given)
        :type location: Location object
        """
        if location is None:
            if self.n_mset:
                self.n_mset.current()
            location = Location()

        # rsplit keeps prefixes that contain '_' intact; the digit pattern
        # avoids deleting unrelated mapsets that merely share the prefix
        prefix = self.msetstr.rsplit('_', 2)[0]
        mapsets = location.mapsets(prefix + '_[0-9]*_[0-9]*')
        for mset in mapsets:
            # delete the mapset in the given location, not the current one
            Mapset(mset, location.name, location.gisdbase).delete()
        if self.n_mset and self.n_mset.is_current():
            self.mset.current()

    def split(self):
        """Split all the raster inputs using r.tile.

        The tiles are created in the current mapset; ``self.inlist`` maps the
        name of each tiled input option to the sorted list of its tiles, so
        that cmd_exe can substitute the option value with a tile. Options
        accepting multiple maps are not split.
        """
        rtile = Module('r.tile')
        inlist = {}
        for key in self.module.inputs:
            par = self.module.inputs[key]
            is_raster = par.type == 'raster' or par.typedesc == 'raster'
            if not is_raster or not par.value or par.multiple:
                continue
            # an output name cannot carry a "@mapset" part
            name = par.value.split('@')[0]
            rtile(input=par.value, output=name,
                  width=self.width, height=self.height,
                  overlap=self.overlap)
            patt = '%s-*' % name
            inlist[key] = sorted(self.mset.glist(type='raster',
                                                 pattern=patt))
        self.inlist = inlist

    def get_works(self):
        """Return a list of tuples with the parameters for cmd_exe function.

        One tuple ``(bbox, inms, gisrc_src, gisrc_dst, cmd, groups)`` is
        built for each tile, and a new GISRC file pointing to the tile
        mapset is written for each of them.
        """
        works = []
        reg = Region()
        if self.move:
            ldst, gdst = read_gisrc(self.gisrc_dst)[1:]
        else:
            ldst, gdst = self.mset.location, self.mset.gisdbase
        cmd = self.module.get_dict()
        groups = list(select(self.module.inputs, 'group'))
        for row, box_row in enumerate(self.bboxes):
            cols = len(box_row)
            for col, box in enumerate(box_row):
                inms = None
                if self.inlist:
                    # the sorted tiles are indexed row by row
                    indx = row * cols + col
                    inms = dict((key, "%s@%s" % (tiles[indx], self.mset.name))
                                for key, tiles in self.inlist.items())
                # set the computational region, prepare the region parameters
                # (the last two items of the bbox are not used)
                bbox = dict([(k[0], str(v))
                             for k, v in list(box.items())[:-2]])
                bbox['nsres'] = '%f' % reg.nsres
                bbox['ewres'] = '%f' % reg.ewres
                # a plain string: a trailing comma here would make a 1-tuple
                new_mset = self.msetstr % (self.start_row + row,
                                           self.start_col + col)
                works.append((bbox, inms,
                              self.gisrc_src,
                              write_gisrc(gdst, ldst, new_mset),
                              cmd, groups))
        return works

    def define_mapset_inputs(self):
        """Add the mapset information to the input maps.

        Inputs without an ``@mapset`` part get the mapset where the map is
        found; maps that cannot be found are left unchanged.
        """
        # NOTE(review): ``select`` also checks ``typedesc``; confirm that
        # ``inm.type`` really holds 'raster'/'vector' so this branch is taken.
        for inmap in self.module.inputs:
            inm = self.module.inputs[inmap]
            if inm.type in ('raster', 'vector') and inm.value:
                if inm.type == 'raster':
                    finder = get_mapset_raster
                else:
                    # imported lazily: only needed for vector inputs
                    from grass.pygrass.utils import get_mapset_vector
                    finder = get_mapset_vector
                if isinstance(inm.value, list):
                    new_value = [self._add_mapset(val, finder)
                                 for val in inm.value]
                else:
                    new_value = self._add_mapset(inm.value, finder)
                if new_value != inm.value:
                    inm.value = new_value

    @staticmethod
    def _add_mapset(name, finder):
        """Return *name* qualified with the mapset found by *finder*."""
        if '@' in name:
            return name
        mset = finder(name)
        # do not produce "name@None" for maps that were not found
        return '%s@%s' % (name, mset) if mset else name

    def run(self, patch=True, clean=True):
        """Run the GRASS command

        :param patch: set False if you does not want to patch the results
        :type patch: bool
        :param clean: set False if you does not want to remove all the stuff
                      created by GridModule
        :type clean: bool
        :raises RuntimeError: if a tile computation was not successful
        """
        self.module.flags.overwrite = True
        self.define_mapset_inputs()
        if self.debug:
            for wrk in self.get_works():
                cmd_exe(wrk)
        else:
            pool = mltp.Pool(processes=self.processes)
            try:
                result = pool.map_async(cmd_exe, self.get_works())
                result.wait()
                pool.close()
            except BaseException:
                # do not leave worker processes behind (e.g. on Ctrl+C)
                pool.terminate()
                raise
            finally:
                pool.join()
            if not result.successful():
                raise RuntimeError(_("Execution of subprocesses was not successful"))

        if patch:
            if self.move:
                os.environ['GISRC'] = self.gisrc_dst
                self.n_mset.current()
                try:
                    self.patch()
                finally:
                    # always switch back to the source mapset
                    os.environ['GISRC'] = self.gisrc_src
                    self.mset.current()
                # copy the outputs from dst => src
                routputs = [self.out_prefix + o
                            for o in select(self.module.outputs, 'raster')]
                copy_rasters(routputs, self.gisrc_dst, self.gisrc_src)
            else:
                self.patch()

        if self.log:
            # record in the temp directory
            from grass.lib.gis import G_tempfile
            tmp, dummy = os.path.split(G_tempfile())
            tmpdir = os.path.join(tmp, self.module.name)
            for k in self.module.outputs:
                par = self.module.outputs[k]
                if par.typedesc == 'raster' and par.value:
                    dirpath = os.path.join(tmpdir, par.name)
                    if not os.path.isdir(dirpath):
                        os.makedirs(dirpath)
                    # empty marker file named after the output map
                    with open(os.path.join(dirpath,
                                           self.out_prefix + par.value),
                              'w+'):
                        pass

        if clean:
            self.clean_location()
            self.rm_tiles()
            if self.n_mset:
                # normpath: a trailing separator would give an empty location
                gisdbase, location = os.path.split(
                    os.path.normpath(self.move))
                self.clean_location(Location(location, gisdbase))
                # rm temporary gis_rc
                os.remove(self.gisrc_dst)
                self.gisrc_dst = None
                sht.rmtree(os.path.join(self.move, 'PERMANENT'))
                sht.rmtree(os.path.join(self.move, self.mset.name))

    def patch(self):
        """Patch the final results.

        :raises RuntimeError: if the module has no raster output set
        """
        bboxes = split_region_tiles(width=self.width, height=self.height)
        loc = Location()
        mset = loc[self.mset.name]
        # make all the mapsets of the location visible from the output mapset
        mset.visible.extend(loc.mapsets())
        noutputs = 0
        for otmap in self.module.outputs:
            otm = self.module.outputs[otmap]
            if otm.typedesc == 'raster' and otm.value:
                rpatch_map(otm.value,
                           self.mset.name, self.msetstr, bboxes,
                           self.module.flags.overwrite,
                           self.start_row, self.start_col, self.out_prefix)
                noutputs += 1
        if noutputs < 1:
            msg = 'No raster output option defined for <{}>'.format(self.module.name)
            if self.module.name == 'r.mapcalc':
                msg += '. Use <{}.simple> instead'.format(self.module.name)
            raise RuntimeError(msg)

    def rm_tiles(self):
        """Remove all the tiles created by :meth:`split`, if any."""
        if self.inlist:
            grm = Module('g.remove')
            for key in self.inlist:
                grm(flags='f', type='raster', name=self.inlist[key])
625