1import logging
2from saml2.attribute_resolver import AttributeResolver
3from saml2.saml import NAMEID_FORMAT_PERSISTENT
4
5logger = logging.getLogger(__name__)
6
7
class VirtualOrg(object):
    """Client-side view of a SAML Virtual Organization (VO).

    Works out which VO member entities still have to be asked about a
    subject and sends them attribute queries through an
    ``AttributeResolver``, identifying the subject by the value of a
    configured "common identifier" attribute.
    """

    def __init__(self, sp, vorg, cnf):
        """
        :param sp: The parent SP client instance.
        :param vorg: Name of the VO; used as the key when looking up the
            VO's members in the metadata.
        :param cnf: VO configuration mapping. ``common_identifier`` is
            required; ``member`` (extra member entity IDs, default: none)
            and ``nameid_format`` (default: NAMEID_FORMAT_PERSISTENT) are
            optional.
        :raises KeyError: if ``common_identifier`` is missing from *cnf*.
        """
        self.sp = sp  # The parent SP client instance
        self._name = vorg
        self.common_identifier = cnf["common_identifier"]
        try:
            self.member = cnf["member"]
        except KeyError:
            self.member = []
        try:
            self.nameid_format = cnf["nameid_format"]
        except KeyError:
            self.nameid_format = NAMEID_FORMAT_PERSISTENT

    def _cache_session(self, session_info):
        """Hook called for each session_info returned by the resolver.

        Currently a no-op that always returns True.
        """
        # NOTE(review): the session_info is discarded here, so aggregated
        # attributes are not stored by this class -- presumably they should
        # be added to self.sp.users; confirm intended behavior.
        return True

    def _affiliation_members(self):
        """
        Get the member of the Virtual Organization from the metadata,
        more specifically from AffiliationDescriptor.
        """
        return self.sp.config.metadata.vo_members(self._name)

    def members_to_ask(self, name_id):
        """Find the members of the Virtual Organization that I haven't
        already spoken to about this subject.

        Candidates are the members listed in the metadata followed by any
        statically configured ``member`` entries not already present.
        Members for which the SP's cache reports active data about
        *name_id* are dropped.

        :param name_id: The subject's NameID.
        :return: List of member entity IDs still to be queried.
        """
        # Copy so that appending the configured members never mutates a
        # list owned by the metadata store.
        vo_members = list(self._affiliation_members())
        for member in self.member:
            if member not in vo_members:
                vo_members.append(member)

        # Remove the ones I have cached data from about this subject
        cache = self.sp.users.cache
        vo_members = [m for m in vo_members if not cache.active(name_id, m)]
        logger.info("VO members (not cached): %s", vo_members)
        return vo_members

    def get_common_identifier(self, name_id):
        """Return the subject's common identifier value, or None.

        Looks up the identity the SP holds for *name_id* and returns the
        first value of the attribute named by ``self.common_identifier``.

        :param name_id: The subject's NameID.
        :return: The first attribute value, or None if no identity
            information is held, the attribute is absent, or it has no
            values.
        """
        (ava, _) = self.sp.users.get_identity(name_id)
        if not ava:
            return None

        try:
            return ava[self.common_identifier][0]
        except (KeyError, IndexError):
            # Missing attribute or an empty value list: nothing to use.
            return None

    def do_aggregation(self, name_id):
        """Query the VO members not yet asked about *name_id*.

        :param name_id: The subject's NameID.
        :return: True if attribute queries were sent; False if there was
            no member left to ask or the subject's common identifier
            value is unknown.
        """
        logger.info("** Do VO aggregation **\nSubjectID: %s, VO:%s",
                    name_id, self._name)

        to_ask = self.members_to_ask(name_id)
        if not to_ask:
            return False

        com_identifier = self.get_common_identifier(name_id)
        if com_identifier is None:
            # Without the common identifier there is no subject to put in
            # the attribute queries, so don't send them.
            logger.warning(
                "No value for common identifier %r, can't do VO aggregation",
                self.common_identifier)
            return False

        resolver = AttributeResolver(self.sp)
        # extend returns a list of session_infos
        for session_info in resolver.extend(
                com_identifier, self.sp.config.entityid, to_ask):
            self._cache_session(session_info)

        logger.info(">Issuers: %s", self.sp.users.issuers_of_info(name_id))
        logger.info("AVA: %s", self.sp.users.get_identity(name_id))

        return True
81