1# -*- coding: utf-8 -*-
2"""
3ldap0.extop - support classes for LDAPv3 extended operations
4"""
5
from typing import Dict, Optional, Type, Union
7
8
class ExtendedRequest:
    """
    Generic base class for a LDAPv3 extended operation request

    requestName
        OID as string of the LDAPv3 extended operation request.
        It is only stored on the instance if passed to __init__() as
        non-None value, otherwise a class attribute of that name
        (e.g. defined by a subclass) is used if present.
    requestValue
        value of the LDAPv3 extended operation request
        (here it is the BER-encoded ASN.1 request value)
    """
    __slots__ = (
        'requestName',
        'requestValue',
    )
    # character encoding for string values; not referenced within this class
    encoding: str = 'utf-8'

    # default intermediate response (typed as Optional IntermediateResponse);
    # not referenced within this class itself
    defaultIntermediateResponse: Optional['IntermediateResponse'] = None

    def __init__(self, requestName=None, requestValue=None):
        """
        requestName
            OID of the request; if None the instance attribute is left
            unset so that a class-level requestName (if any) applies
        requestValue
            BER-encoded ASN.1 request value (may be None)
        """
        if requestName is not None:
            self.requestName = requestName
        self.requestValue = requestValue

    def __repr__(self):
        # requestName may never have been set (None passed to __init__() and
        # no class attribute defined), so fall back to None instead of letting
        # the unset slot raise AttributeError
        return '%s(requestName=%r, requestValue=%r)' % (
            self.__class__.__name__,
            getattr(self, 'requestName', None),
            self.requestValue,
        )

    def encode(self) -> bytes:
        """
        returns the BER-encoded ASN.1 request value composed by class attributes
        set before

        This base implementation simply returns requestValue unchanged.
        """
        return self.requestValue
45
46
class ExtendedResponse:
    """
    Generic base class for a LDAPv3 extended operation response

    responseName
        OID as string of the LDAPv3 extended operation response
        (class attribute, None in this base class)
    responseValue
        BER-encoded ASN.1 value of the LDAPv3 extended operation response
        as passed to __init__() resp. decode()
    """
    __slots__ = (
        'responseValue',
    )
    encoding = 'utf-8'

    responseName = None

    def __init__(self, encodedResponseValue=None):
        """
        encodedResponseValue
            BER-encoded ASN.1 response value; if not None it is passed
            to decode()
        """
        # always initialize the slot so that __repr__() works even if no
        # value was given and decode() is never called
        self.responseValue = encodedResponseValue
        if encodedResponseValue is not None:
            self.decode(encodedResponseValue)

    def __repr__(self):
        return '%s(encodedResponseValue=%r)' % (
            self.__class__.__name__,
            self.responseValue,
        )

    @classmethod
    def check_resp_name(cls, name):
        """
        returns True if :name: is the correct expected responseName
        """
        return name == cls.responseName

    def decode(self, value: bytes):
        """
        decodes the BER-encoded ASN.1 extended operation response value and
        sets the appropriate class attributes

        This base implementation only stores the raw value in responseValue.
        """
        self.responseValue = value
87
88
class IntermediateResponse:
    """
    Generic base class for a LDAPv3 intermediate response

    responseName
        OID as string of the LDAPv3 intermediate response
        (class attribute, may be overridden per instance via __init__())
    responseValue
        BER-encoded ASN.1 value of the LDAPv3 intermediate response
    ctrls
        list of controls passed to __init__() (empty list if none given)
    """
    encoding = 'utf-8'

    responseName = None

    def __init__(self, responseName=None, encodedResponseValue=None, ctrls=None):
        """
        responseName
            OID of the intermediate response; if None the class-level
            responseName applies
        encodedResponseValue
            BER-encoded ASN.1 response value; if not None it is passed
            to decode()
        ctrls
            controls received with the response; None is stored as []
        """
        if responseName is not None:
            self.responseName = responseName
        # set before decode() so the attribute exists even without a value
        self.responseValue = encodedResponseValue
        if encodedResponseValue is not None:
            self.decode(encodedResponseValue)
        self.ctrls = ctrls or []

    def __repr__(self):
        return '%s(responseName=%r, responseValue=%r)' % (
            self.__class__.__name__,
            self.responseName,
            self.responseValue,
        )

    @classmethod
    def check_resp_name(cls, name):
        """
        returns True if :name: is the correct expected responseName
        (compared against the class attribute, not an instance override)
        """
        return name == cls.responseName

    def decode(self, value: bytes):
        """
        decodes the BER-encoded ASN.1 intermediate response value and
        sets the appropriate class attributes

        This base implementation only stores the raw value in responseValue.
        """
        self.responseValue = value
130
131
class ExtOpResponseRegistry:
    """
    A simple registry mapping response OIDs (responseName) to the
    handler classes for extended/intermediate responses
    """

    # maps responseName -> handler *class* (not instance)
    _handler_cls: Dict[str, Type[Union[ExtendedResponse, IntermediateResponse]]]

    def __init__(self):
        self._handler_cls = {}

    def register(
            self,
            handler: Type[Union[ExtendedResponse, IntermediateResponse]],
        ) -> Type[Union[ExtendedResponse, IntermediateResponse]]:
        """
        register a handler class for extended/intermediate response
        under its class attribute responseName

        A handler class previously registered under the same responseName
        is replaced. The handler class is returned unchanged so that this
        method can also be used as class decorator.
        """
        self._handler_cls[handler.responseName] = handler
        return handler

    def get(
            self,
            oid: str,
            default: Optional[Type[Union[ExtendedResponse, IntermediateResponse]]] = None,
        ) -> Optional[Type[Union[ExtendedResponse, IntermediateResponse]]]:
        """
        return handler class for extended/intermediate response by OID (responseName)
        or :default: if no handler class is registered for :oid:
        """
        return self._handler_cls.get(oid, default)
160
161
# Module-wide registries mapping a response OID (responseName) to its
# handler class: the first for extended operation responses, the second
# for intermediate responses (two separate registry instances).
EXTOP_RESPONSE_REGISTRY, INTERMEDIATE_RESPONSE_REGISTRY = (
    ExtOpResponseRegistry(),
    ExtOpResponseRegistry(),
)
165