1# Copyright 2019 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""A helper for jobs that have been created on the Quantum Engine."""
15import datetime
16import time
17
18from typing import Dict, Iterator, List, Optional, overload, Tuple, TYPE_CHECKING
19
20import cirq
21from cirq_google.engine import calibration, engine_client
22from cirq_google.engine.calibration_result import CalibrationResult
23from cirq_google.engine.client import quantum
24from cirq_google.engine.result_type import ResultType
25from cirq_google.api import v1, v2
26
27if TYPE_CHECKING:
28    import datetime
29    import cirq_google.engine.engine as engine_base
30    from cirq_google.engine.engine import engine_program
31    from cirq_google.engine.engine import engine_processor
32
# Execution states after which a job is considered finished: polling in
# `EngineJob._wait_for_result` stops once a job reaches one of these, and
# `EngineJob._refresh_job` no longer re-fetches a job that is in one of them.
_State = quantum.enums.ExecutionStatus.State
TERMINAL_STATES = [_State.SUCCESS, _State.FAILURE, _State.CANCELLED]
del _State  # Spelling shortcut only; keep it out of the module namespace.
38
39
class EngineJob:
    """A job created via the Quantum Engine API.

    This job may be in a variety of states. It may be scheduling, it may be
    executing on a machine, or it may have entered a terminal state
    (either succeeding or failing).

    `EngineJob`s can be iterated over, returning `Result`s. These
    `Result`s can also be accessed by index. Note that this will block
    until the results are returned from the Engine service.

    Attributes:
      project_id: A project_id of the parent Google Cloud Project.
      program_id: Unique ID of the program within the parent project.
      job_id: Unique ID of the job within the parent program.
      context: Engine configuration and context used for service calls.
      result_type: What type of results are expected from this job.
    """

    def __init__(
        self,
        project_id: str,
        program_id: str,
        job_id: str,
        context: 'engine_base.EngineContext',
        _job: Optional[quantum.types.QuantumJob] = None,
        result_type: ResultType = ResultType.Program,
    ) -> None:
        """A job submitted to the engine.

        Args:
            project_id: A project_id of the parent Google Cloud Project.
            program_id: Unique ID of the program within the parent project.
            job_id: Unique ID of the job within the parent program.
            context: Engine configuration and context to use.
            _job: The optional current job state.
            result_type: What type of results are expected, such as
               batched results or the result of a focused calibration.
        """
        self.project_id = project_id
        self.program_id = program_id
        self.job_id = job_id
        self.context = context
        self._job = _job
        # Lazily populated result caches; `None` means "not fetched yet", so an
        # empty list is a valid cached value and is not fetched again.
        self._results: Optional[List[cirq.Result]] = None
        self._calibration_results: Optional[List[CalibrationResult]] = None
        self._batched_results: Optional[List[List[cirq.Result]]] = None
        self.result_type = result_type

    def engine(self) -> 'engine_base.Engine':
        """Returns the parent Engine object."""
        # Imported here rather than at module level (presumably to avoid a
        # circular import; the file-level import is TYPE_CHECKING-only).
        import cirq_google.engine.engine as engine_base

        return engine_base.Engine(self.project_id, context=self.context)

    def program(self) -> 'engine_program.EngineProgram':
        """Returns the parent EngineProgram object."""
        import cirq_google.engine.engine_program as engine_program

        return engine_program.EngineProgram(self.project_id, self.program_id, self.context)

    def _inner_job(self) -> quantum.types.QuantumJob:
        """Returns the cached job proto, fetching it only if nothing is cached.

        The cached copy is not refreshed, so its execution status may be stale;
        `_refresh_job` should be used when the current state matters.
        """
        if not self._job:
            self._job = self.context.client.get_job(
                self.project_id, self.program_id, self.job_id, False
            )
        return self._job

    def _refresh_job(self) -> quantum.types.QuantumJob:
        """Returns the job proto, re-fetching it unless it is already terminal."""
        if not self._job or self._job.execution_status.state not in TERMINAL_STATES:
            self._job = self.context.client.get_job(
                self.project_id, self.program_id, self.job_id, False
            )
        return self._job

    def create_time(self) -> 'datetime.datetime':
        """Returns when the job was created."""
        return self._inner_job().create_time.ToDatetime()

    def update_time(self) -> 'datetime.datetime':
        """Returns when the job was last updated.

        Always re-fetches the job from the service so the value is current.
        """
        self._job = self.context.client.get_job(
            self.project_id, self.program_id, self.job_id, False
        )
        return self._job.update_time.ToDatetime()

    def description(self) -> str:
        """Returns the description of the job."""
        return self._inner_job().description

    def set_description(self, description: str) -> 'EngineJob':
        """Sets the description of the job.

        Args:
            description: The new description for the job.

        Returns:
            This EngineJob.
        """
        self._job = self.context.client.set_job_description(
            self.project_id, self.program_id, self.job_id, description
        )
        return self

    def labels(self) -> Dict[str, str]:
        """Returns the labels of the job."""
        return self._inner_job().labels

    def set_labels(self, labels: Dict[str, str]) -> 'EngineJob':
        """Sets (overwriting) the labels for a previously created quantum job.

        Args:
            labels: The entire set of new job labels.

        Returns:
            This EngineJob.
        """
        self._job = self.context.client.set_job_labels(
            self.project_id, self.program_id, self.job_id, labels
        )
        return self

    def add_labels(self, labels: Dict[str, str]) -> 'EngineJob':
        """Adds new labels to a previously created quantum job.

        Args:
            labels: New labels to add to the existing job labels.

        Returns:
            This EngineJob.
        """
        self._job = self.context.client.add_job_labels(
            self.project_id, self.program_id, self.job_id, labels
        )
        return self

    def remove_labels(self, keys: List[str]) -> 'EngineJob':
        """Removes labels with given keys from the labels of a previously
        created quantum job.

        Args:
            keys: Label keys to remove from the existing job labels.

        Returns:
            This EngineJob.
        """
        self._job = self.context.client.remove_job_labels(
            self.project_id, self.program_id, self.job_id, keys
        )
        return self

    def processor_ids(self) -> List[str]:
        """Returns the processor ids provided when the job was created."""
        return [
            engine_client._ids_from_processor_name(p)[1]
            for p in self._inner_job().scheduling_config.processor_selector.processor_names
        ]

    def status(self) -> str:
        """Return the execution status of the job."""
        return quantum.types.ExecutionStatus.State.Name(self._refresh_job().execution_status.state)

    def failure(self) -> Optional[Tuple[str, str]]:
        """Return failure code and message of the job if present.

        Uses the cached job state; this call does not refresh it.
        """
        status = self._inner_job().execution_status
        if not status.HasField('failure'):
            return None
        failure = status.failure
        return (
            quantum.types.ExecutionStatus.Failure.Code.Name(failure.error_code),
            failure.error_message,
        )

    def get_repetitions_and_sweeps(self) -> Tuple[int, List[cirq.Sweep]]:
        """Returns the repetitions and sweeps for the Quantum Engine job.

        Returns:
            A tuple of the repetition count and list of sweeps.
        """
        if not self._job or not self._job.HasField('run_context'):
            # NOTE(review): the trailing `True` presumably asks the client to
            # include `run_context` in the response -- confirm in engine_client.
            self._job = self.context.client.get_job(
                self.project_id, self.program_id, self.job_id, True
            )

        return _deserialize_run_context(self._job.run_context)

    def get_processor(self) -> 'Optional[engine_processor.EngineProcessor]':
        """Returns the EngineProcessor for the processor the job is/was run on,
        if available, else None."""
        status = self._inner_job().execution_status
        if not status.processor_name:
            return None
        import cirq_google.engine.engine_processor as engine_processor

        ids = engine_client._ids_from_processor_name(status.processor_name)
        return engine_processor.EngineProcessor(ids[0], ids[1], self.context)

    def get_calibration(self) -> Optional[calibration.Calibration]:
        """Returns the recorded calibration at the time when the job was run, if
        one was captured, else None."""
        status = self._inner_job().execution_status
        if not status.calibration_name:
            return None
        ids = engine_client._ids_from_calibration_name(status.calibration_name)
        response = self.context.client.get_calibration(*ids)
        metrics = v2.metrics_pb2.MetricsSnapshot.FromString(response.data.value)
        return calibration.Calibration(metrics)

    def cancel(self) -> None:
        """Cancel the job."""
        self.context.client.cancel_job(self.project_id, self.program_id, self.job_id)

    def delete(self) -> None:
        """Deletes the job and result, if any."""
        self.context.client.delete_job(self.project_id, self.program_id, self.job_id)

    def batched_results(self) -> List[List[cirq.Result]]:
        """Returns the job results, blocking until the job is complete.

        This method is intended for batched jobs.  Instead of flattening
        results into a single list, this will return a List[Result]
        for each circuit in the batch.

        Raises:
            ValueError: If the job's results are not batch results.
        """
        self.results()
        # Compare against None so that an empty batch is returned as [] rather
        # than being reported as a non-batch result.
        if self._batched_results is None:
            raise ValueError('batched_results called for a non-batch result.')
        return self._batched_results

    def _wait_for_result(self):
        """Blocks until the job is terminal (or times out) and returns its result.

        Polls every 0.5 seconds. `context.timeout` bounds the accumulated sleep
        time; a falsy timeout means wait indefinitely.

        Returns:
            The `result` field of the client's `get_job_results` response.

        Raises:
            RuntimeError: If the job did not succeed or the wait timed out.
        """
        job = self._refresh_job()
        total_seconds_waited = 0.0
        timeout = self.context.timeout
        while True:
            if timeout and total_seconds_waited >= timeout:
                break
            if job.execution_status.state in TERMINAL_STATES:
                break
            time.sleep(0.5)
            total_seconds_waited += 0.5
            job = self._refresh_job()
        _raise_on_failure(job)
        response = self.context.client.get_job_results(
            self.project_id, self.program_id, self.job_id
        )
        return response.result

    def results(self) -> List[cirq.Result]:
        """Returns the job results, blocking until the job is complete.

        Results are fetched and parsed once, then cached on this object.

        Raises:
            RuntimeError: If the job did not succeed or the wait timed out.
            ValueError: If the result proto has an unrecognized type.
        """
        import cirq_google.engine.engine as engine_base

        # `is None` (not truthiness) so an empty result list is cached too.
        if self._results is None:
            result = self._wait_for_result()
            result_type = result.type_url[len(engine_base.TYPE_PREFIX) :]
            if result_type in ('cirq.google.api.v1.Result', 'cirq.api.google.v1.Result'):
                v1_parsed_result = v1.program_pb2.Result.FromString(result.value)
                self._results = _get_job_results_v1(v1_parsed_result)
            elif result_type in ('cirq.google.api.v2.Result', 'cirq.api.google.v2.Result'):
                v2_parsed_result = v2.result_pb2.Result.FromString(result.value)
                self._results = _get_job_results_v2(v2_parsed_result)
            elif result.Is(v2.batch_pb2.BatchResult.DESCRIPTOR):
                v2_parsed_result = v2.batch_pb2.BatchResult.FromString(result.value)
                self._batched_results = self._get_batch_results_v2(v2_parsed_result)
                self._results = self._flatten(self._batched_results)
            else:
                raise ValueError(f'invalid result proto version: {result_type}')
        return self._results

    def calibration_results(self) -> List[CalibrationResult]:
        """Returns the results of a run_calibration() call.

        This function will fail if any other type of results were returned
        by the Engine.

        Raises:
            RuntimeError: If the job did not succeed or the wait timed out.
            ValueError: If the result is not a FocusedCalibrationResult.
        """
        import cirq_google.engine.engine as engine_base

        # `is None` (not truthiness) so an empty result list is cached too.
        if self._calibration_results is None:
            result = self._wait_for_result()
            result_type = result.type_url[len(engine_base.TYPE_PREFIX) :]
            if result_type != 'cirq.google.api.v2.FocusedCalibrationResult':
                raise ValueError(f'Did not find calibration results, instead found: {result_type}')
            parsed_val = v2.calibration_pb2.FocusedCalibrationResult.FromString(result.value)
            cal_results: List[CalibrationResult] = []
            for layer in parsed_val.results:
                metrics = calibration.Calibration(layer.metrics)
                # Empty proto strings are normalized to None.
                message = layer.error_message or None
                token = layer.token or None
                # NOTE(review): `fromtimestamp` without tz yields a naive
                # local-time datetime; confirm callers expect that over UTC.
                ts = (
                    datetime.datetime.fromtimestamp(layer.valid_until_ms / 1000)
                    if layer.valid_until_ms > 0
                    else None
                )
                cal_results.append(CalibrationResult(layer.code, message, token, ts, metrics))
            self._calibration_results = cal_results
        return self._calibration_results

    @classmethod
    def _get_batch_results_v2(cls, results: v2.batch_pb2.BatchResult) -> List[List[cirq.Result]]:
        """Converts a BatchResult proto into one result list per batched circuit."""
        return [_get_job_results_v2(result) for result in results.results]

    @classmethod
    def _flatten(cls, result: List[List[cirq.Result]]) -> List[cirq.Result]:
        """Concatenates per-circuit result lists into one list, preserving order."""
        return [res for result_list in result for res in result_list]

    def __iter__(self) -> Iterator[cirq.Result]:
        return iter(self.results())

    # pylint: disable=function-redefined
    @overload
    def __getitem__(self, item: int) -> cirq.Result:
        pass

    @overload
    def __getitem__(self, item: slice) -> List[cirq.Result]:
        pass

    def __getitem__(self, item):
        return self.results()[item]

    # pylint: enable=function-redefined

    def __len__(self) -> int:
        return len(self.results())

    def __str__(self) -> str:
        return (
            f'EngineJob(project_id=\'{self.project_id}\', '
            f'program_id=\'{self.program_id}\', job_id=\'{self.job_id}\')'
        )
374
375
def _deserialize_run_context(
    run_context: quantum.types.any_pb2.Any,
) -> Tuple[int, List[cirq.Sweep]]:
    """Decodes a packed RunContext into its repetitions and sweeps.

    Args:
        run_context: A packed proto whose type URL names a RunContext.

    Returns:
        The repetition count of the first parameter sweep, paired with every
        parameter sweep converted via `v2.sweep_from_proto`.

    Raises:
        ValueError: If the context is a v1 RunContext (not supported) or of
            any type other than a v2 RunContext.
    """
    import cirq_google.engine.engine as engine_base

    type_name = run_context.type_url[len(engine_base.TYPE_PREFIX) :]
    # Two package spellings are accepted for each RunContext version.
    if type_name in ('cirq.google.api.v1.RunContext', 'cirq.api.google.v1.RunContext'):
        raise ValueError('deserializing a v1 RunContext is not supported')
    if type_name not in ('cirq.google.api.v2.RunContext', 'cirq.api.google.v2.RunContext'):
        raise ValueError(f'unsupported run_context type: {type_name}')

    parsed = v2.run_context_pb2.RunContext.FromString(run_context.value)
    # Only the first sweep's repetition count is reported.
    repetitions = parsed.parameter_sweeps[0].repetitions
    sweeps = [v2.sweep_from_proto(entry.sweep) for entry in parsed.parameter_sweeps]
    return repetitions, sweeps
396
397
def _get_job_results_v1(result: v1.program_pb2.Result) -> List[cirq.Result]:
    """Converts a v1 `Result` proto into a flat list of `cirq.Result`s.

    Args:
        result: The parsed v1 program result.

    Returns:
        One `cirq.Result` per parameterized result, across all sweep results,
        in the order they appear in the proto.
    """
    trial_results: List[cirq.Result] = []
    for sweep_result in result.sweep_results:
        sweep_repetitions = sweep_result.repetitions
        key_sizes = [(m.key, len(m.qubits)) for m in sweep_result.measurement_keys]
        # Named `param_result` so it does not shadow the `result` argument.
        for param_result in sweep_result.parameterized_results:
            measurements = v1.unpack_results(
                param_result.measurement_results, sweep_repetitions, key_sizes
            )
            trial_results.append(
                cirq.Result.from_single_parameter_set(
                    params=cirq.ParamResolver(param_result.params.assignments),
                    measurements=measurements,
                )
            )
    return trial_results
414
415
def _get_job_results_v2(result: v2.result_pb2.Result) -> List[cirq.Result]:
    """Converts a v2 `Result` proto into a flat list of `cirq.Result`s.

    `v2.results_from_proto` returns results grouped per sweep; the groups are
    concatenated in order so the output is a single flat list, matching the
    sampler API.
    """
    flattened: List[cirq.Result] = []
    for per_sweep in v2.results_from_proto(result):
        flattened.extend(per_sweep)
    return flattened
420
421
def _raise_on_failure(job: quantum.types.QuantumJob) -> None:
    """Raises `RuntimeError` unless `job` is in the SUCCESS state.

    Args:
        job: The job whose execution status is inspected.

    Raises:
        RuntimeError: Including processor and failure code/message if the job
            failed; the state name if it ended in another terminal state; or
            a timeout message if it has not reached a terminal state.
    """
    status = job.execution_status
    state = status.state
    if state == quantum.enums.ExecutionStatus.State.SUCCESS:
        return

    state_name_of = quantum.types.ExecutionStatus.State.Name
    if state == quantum.enums.ExecutionStatus.State.FAILURE:
        processor = status.processor_name or 'UNKNOWN'
        code_name = quantum.types.ExecutionStatus.Failure.Code.Name(status.failure.error_code)
        raise RuntimeError(
            f"Job {job.name} on processor {processor} failed. "
            f"{code_name}: {status.failure.error_message}"
        )
    if state in TERMINAL_STATES:
        raise RuntimeError(f'Job {job.name} failed in state {state_name_of(state)}.')
    raise RuntimeError(
        f'Timed out waiting for results. Job {job.name} is in state {state_name_of(state)}'
    )
452