1# Copyright 2019 The Cirq Developers
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
7#     https://www.apache.org/licenses/LICENSE-2.0
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""A helper for jobs that have been created on the Quantum Engine."""
15import datetime
16import time
18from typing import Dict, Iterator, List, Optional, overload, Tuple, TYPE_CHECKING
20import cirq
21from cirq_google.engine import calibration, engine_client
22from cirq_google.engine.calibration_result import CalibrationResult
23from cirq_google.engine.client import quantum
24from cirq_google.engine.result_type import ResultType
25from cirq_google.api import v1, v2
28    import datetime
29    import cirq_google.engine.engine as engine_base
30    from cirq_google.engine.engine import engine_program
31    from cirq_google.engine.engine import engine_processor
34    quantum.enums.ExecutionStatus.State.SUCCESS,
35    quantum.enums.ExecutionStatus.State.FAILURE,
36    quantum.enums.ExecutionStatus.State.CANCELLED,
class EngineJob:
    """A job created via the Quantum Engine API.

    This job may be in a variety of states. It may be scheduling, it may be
    executing on a machine, or it may have entered a terminal state
    (either succeeding or failing).

    `EngineJob`s can be iterated over, returning `Result`s. These
    `Result`s can also be accessed by index. Note that this will block
    until the results are returned from the Engine service.

    Attributes:
      project_id: A project_id of the parent Google Cloud Project.
      program_id: Unique ID of the program within the parent project.
      job_id: Unique ID of the job within the parent program.
    """

    def __init__(
        self,
        project_id: str,
        program_id: str,
        job_id: str,
        context: 'engine_base.EngineContext',
        _job: Optional[quantum.types.QuantumJob] = None,
        result_type: ResultType = ResultType.Program,
    ) -> None:
        """A job submitted to the engine.

        Args:
            project_id: A project_id of the parent Google Cloud Project.
            program_id: Unique ID of the program within the parent project.
            job_id: Unique ID of the job within the parent program.
            context: Engine configuration and context to use.
            _job: The optional current job state.
            result_type: What type of results are expected, such as
               batched results or the result of a focused calibration.
        """
        self.project_id = project_id
        self.program_id = program_id
        self.job_id = job_id
        self.context = context
        self._job = _job
        # Lazily populated result caches; ``None`` means "not fetched yet".
        self._results: Optional[List[cirq.Result]] = None
        # calibration_results() stores a list here, one entry per calibration layer.
        self._calibration_results: Optional[List[CalibrationResult]] = None
        self._batched_results: Optional[List[List[cirq.Result]]] = None
        self.result_type = result_type

    def engine(self) -> 'engine_base.Engine':
        """Returns the parent Engine object."""
        import cirq_google.engine.engine as engine_base

        return engine_base.Engine(self.project_id, context=self.context)

    def program(self) -> 'engine_program.EngineProgram':
        """Returns the parent EngineProgram object."""
        import cirq_google.engine.engine_program as engine_program

        return engine_program.EngineProgram(self.project_id, self.program_id, self.context)

    def _inner_job(self) -> quantum.types.QuantumJob:
        """Returns the cached job proto, fetching it from the service if absent."""
        if not self._job:
            # NOTE(review): the final ``False`` presumably asks the client to omit
            # the run context; get_repetitions_and_sweeps passes ``True`` when it
            # needs it. Confirm against engine_client.get_job.
            self._job = self.context.client.get_job(
                self.project_id, self.program_id, self.job_id, False
            )
        return self._job

    def _refresh_job(self) -> quantum.types.QuantumJob:
        """Re-fetches the job proto unless it is cached and already terminal."""
        if not self._job or self._job.execution_status.state not in TERMINAL_STATES:
            self._job = self.context.client.get_job(
                self.project_id, self.program_id, self.job_id, False
            )
        return self._job

    def create_time(self) -> 'datetime.datetime':
        """Returns when the job was created."""
        return self._inner_job().create_time.ToDatetime()

    def update_time(self) -> 'datetime.datetime':
        """Returns when the job was last updated.

        The job is always re-fetched (replacing the cached copy) so that the
        returned timestamp is current.
        """
        self._job = self.context.client.get_job(
            self.project_id, self.program_id, self.job_id, False
        )
        return self._job.update_time.ToDatetime()

    def description(self) -> str:
        """Returns the description of the job."""
        return self._inner_job().description

    def set_description(self, description: str) -> 'EngineJob':
        """Sets the description of the job.

        Args:
            description: The new description for the job.

        Returns:
             This EngineJob.
        """
        self._job = self.context.client.set_job_description(
            self.project_id, self.program_id, self.job_id, description
        )
        return self

    def labels(self) -> Dict[str, str]:
        """Returns the labels of the job."""
        return self._inner_job().labels

    def set_labels(self, labels: Dict[str, str]) -> 'EngineJob':
        """Sets (overwriting) the labels for a previously created quantum job.

        Args:
            labels: The entire set of new job labels.

        Returns:
             This EngineJob.
        """
        self._job = self.context.client.set_job_labels(
            self.project_id, self.program_id, self.job_id, labels
        )
        return self

    def add_labels(self, labels: Dict[str, str]) -> 'EngineJob':
        """Adds new labels to a previously created quantum job.

        Args:
            labels: New labels to add to the existing job labels.

        Returns:
             This EngineJob.
        """
        self._job = self.context.client.add_job_labels(
            self.project_id, self.program_id, self.job_id, labels
        )
        return self

    def remove_labels(self, keys: List[str]) -> 'EngineJob':
        """Removes labels with given keys from the labels of a previously
        created quantum job.

        Args:
            keys: Label keys to remove from the existing job labels.

        Returns:
            This EngineJob.
        """
        self._job = self.context.client.remove_job_labels(
            self.project_id, self.program_id, self.job_id, keys
        )
        return self

    def processor_ids(self) -> List[str]:
        """Returns the processor ids provided when the job was created."""
        # Element [1] of the parsed resource name is presumably the processor id.
        return [
            engine_client._ids_from_processor_name(p)[1]
            for p in self._inner_job().scheduling_config.processor_selector.processor_names
        ]

    def status(self) -> str:
        """Return the name of the job's execution state, refreshing it if needed."""
        return quantum.types.ExecutionStatus.State.Name(self._refresh_job().execution_status.state)

    def failure(self) -> Optional[Tuple[str, str]]:
        """Return failure code and message of the job if present."""
        execution_status = self._inner_job().execution_status
        if execution_status.HasField('failure'):
            failure = execution_status.failure
            return (
                quantum.types.ExecutionStatus.Failure.Code.Name(failure.error_code),
                failure.error_message,
            )
        return None

    def get_repetitions_and_sweeps(self) -> Tuple[int, List[cirq.Sweep]]:
        """Returns the repetitions and sweeps for the Quantum Engine job.

        The job is re-fetched with its run context when the cached copy lacks
        one.

        Returns:
            A tuple of the repetition count and list of sweeps.
        """
        if not self._job or not self._job.HasField('run_context'):
            self._job = self.context.client.get_job(
                self.project_id, self.program_id, self.job_id, True
            )

        return _deserialize_run_context(self._job.run_context)

    def get_processor(self) -> 'Optional[engine_processor.EngineProcessor]':
        """Returns the EngineProcessor for the processor the job is/was run on,
        if available, else None."""
        status = self._inner_job().execution_status
        if not status.processor_name:
            return None
        import cirq_google.engine.engine_processor as engine_processor

        ids = engine_client._ids_from_processor_name(status.processor_name)
        return engine_processor.EngineProcessor(ids[0], ids[1], self.context)

    def get_calibration(self) -> Optional[calibration.Calibration]:
        """Returns the recorded calibration at the time when the job was run, if
        one was captured, else None."""
        status = self._inner_job().execution_status
        if not status.calibration_name:
            return None
        ids = engine_client._ids_from_calibration_name(status.calibration_name)
        response = self.context.client.get_calibration(*ids)
        metrics = v2.metrics_pb2.MetricsSnapshot.FromString(response.data.value)
        return calibration.Calibration(metrics)

    def cancel(self) -> None:
        """Cancel the job."""
        self.context.client.cancel_job(self.project_id, self.program_id, self.job_id)

    def delete(self) -> None:
        """Deletes the job and result, if any."""
        self.context.client.delete_job(self.project_id, self.program_id, self.job_id)

    def batched_results(self) -> List[List[cirq.Result]]:
        """Returns the job results, blocking until the job is complete.

        This method is intended for batched jobs.  Instead of flattening
        results into a single list, this will return a List[Result]
        for each circuit in the batch.

        Raises:
            ValueError: If the job did not produce a batch result.
        """
        self.results()
        # Compare against None so that an empty batch is not mistaken for a
        # non-batch result.
        if self._batched_results is None:
            raise ValueError('batched_results called for a non-batch result.')
        return self._batched_results

    def _wait_for_result(self) -> quantum.types.any_pb2.Any:
        """Blocks until the job finishes (or times out) and returns its packed result.

        The job is polled every 0.5 seconds. ``context.timeout`` bounds the
        total wait; a falsy timeout means wait indefinitely.

        Raises:
            RuntimeError: If the job did not reach the SUCCESS state.
        """
        job = self._refresh_job()
        total_seconds_waited = 0.0
        timeout = self.context.timeout
        while True:
            if timeout and total_seconds_waited >= timeout:
                break
            if job.execution_status.state in TERMINAL_STATES:
                break
            time.sleep(0.5)
            total_seconds_waited += 0.5
            job = self._refresh_job()
        _raise_on_failure(job)
        response = self.context.client.get_job_results(
            self.project_id, self.program_id, self.job_id
        )
        return response.result

    def results(self) -> List[cirq.Result]:
        """Returns the job results, blocking until the job is complete.

        Parsed results are cached, so later calls do not contact the service
        again (even when the job produced no results).

        Raises:
            RuntimeError: If the job failed, was cancelled, or timed out.
            ValueError: If the result proto is of an unrecognized type.
        """
        import cirq_google.engine.engine as engine_base

        if self._results is None:
            result = self._wait_for_result()
            result_type = result.type_url[len(engine_base.TYPE_PREFIX) :]
            if result_type in ('cirq.google.api.v1.Result', 'cirq.api.google.v1.Result'):
                v1_parsed_result = v1.program_pb2.Result.FromString(result.value)
                self._results = _get_job_results_v1(v1_parsed_result)
            elif result_type in ('cirq.google.api.v2.Result', 'cirq.api.google.v2.Result'):
                v2_parsed_result = v2.result_pb2.Result.FromString(result.value)
                self._results = _get_job_results_v2(v2_parsed_result)
            elif result.Is(v2.batch_pb2.BatchResult.DESCRIPTOR):
                v2_parsed_result = v2.batch_pb2.BatchResult.FromString(result.value)
                self._batched_results = self._get_batch_results_v2(v2_parsed_result)
                self._results = self._flatten(self._batched_results)
            else:
                raise ValueError(f'invalid result proto version: {result_type}')
        return self._results

    def calibration_results(self) -> List[CalibrationResult]:
        """Returns the results of a run_calibration() call.

        This function will fail if any other type of results were returned
        by the Engine.

        Returns:
            One CalibrationResult per layer in the focused calibration result.

        Raises:
            ValueError: If the job's result is not a FocusedCalibrationResult.
        """
        import cirq_google.engine.engine as engine_base

        if self._calibration_results is None:
            result = self._wait_for_result()
            result_type = result.type_url[len(engine_base.TYPE_PREFIX) :]
            if result_type != 'cirq.google.api.v2.FocusedCalibrationResult':
                raise ValueError(f'Did not find calibration results, instead found: {result_type}')
            parsed_val = v2.calibration_pb2.FocusedCalibrationResult.FromString(result.value)
            cal_results = []
            for layer in parsed_val.results:
                metrics = calibration.Calibration(layer.metrics)
                # Empty proto strings are normalized to None.
                message = layer.error_message or None
                token = layer.token or None
                if layer.valid_until_ms > 0:
                    # NOTE(review): produces a naive local-time datetime.
                    ts = datetime.datetime.fromtimestamp(layer.valid_until_ms / 1000)
                else:
                    ts = None
                cal_results.append(CalibrationResult(layer.code, message, token, ts, metrics))
            self._calibration_results = cal_results
        return self._calibration_results

    @classmethod
    def _get_batch_results_v2(cls, results: v2.batch_pb2.BatchResult) -> List[List[cirq.Result]]:
        """Parses a batch result into one list of results per batched circuit."""
        return [_get_job_results_v2(result) for result in results.results]

    @classmethod
    def _flatten(cls, result) -> List[cirq.Result]:
        """Concatenates a list of result lists into a single list."""
        return [res for result_list in result for res in result_list]

    def __iter__(self) -> Iterator[cirq.Result]:
        return iter(self.results())

    # pylint: disable=function-redefined
    @overload
    def __getitem__(self, item: int) -> cirq.Result:
        pass

    @overload
    def __getitem__(self, item: slice) -> List[cirq.Result]:
        pass

    def __getitem__(self, item):
        return self.results()[item]

    # pylint: enable=function-redefined

    def __len__(self) -> int:
        return len(self.results())

    def __str__(self) -> str:
        return (
            f'EngineJob(project_id=\'{self.project_id}\', '
            f'program_id=\'{self.program_id}\', job_id=\'{self.job_id}\')'
        )
def _deserialize_run_context(
    run_context: quantum.types.any_pb2.Any,
) -> Tuple[int, List[cirq.Sweep]]:
    """Decodes a packed RunContext into its repetition count and sweeps.

    Args:
        run_context: The packed run context attached to a job.

    Returns:
        A tuple of the repetition count of the first parameter sweep and the
        list of sweeps, one per parameter sweep in the context.

    Raises:
        ValueError: If the context is a v1 RunContext (unsupported) or any
            type other than a v2 RunContext.
    """
    import cirq_google.engine.engine as engine_base

    type_name = run_context.type_url[len(engine_base.TYPE_PREFIX) :]
    if type_name in ('cirq.google.api.v1.RunContext', 'cirq.api.google.v1.RunContext'):
        raise ValueError('deserializing a v1 RunContext is not supported')
    if type_name not in ('cirq.google.api.v2.RunContext', 'cirq.api.google.v2.RunContext'):
        raise ValueError(f'unsupported run_context type: {type_name}')

    parsed = v2.run_context_pb2.RunContext.FromString(run_context.value)
    # NOTE(review): only the first sweep's repetitions are reported, which
    # presumably assumes all sweeps share one count; an empty
    # parameter_sweeps list raises IndexError here.
    repetitions = parsed.parameter_sweeps[0].repetitions
    sweeps = [v2.sweep_from_proto(param_sweep.sweep) for param_sweep in parsed.parameter_sweeps]
    return repetitions, sweeps
def _get_job_results_v1(result: v1.program_pb2.Result) -> List[cirq.Result]:
    """Converts a v1 result proto into a flat list of `cirq.Result`s.

    Args:
        result: The parsed v1 `Result` proto.

    Returns:
        One `cirq.Result` per parameterized result, across all sweep results,
        in proto order.
    """
    trial_results = []
    for sweep_result in result.sweep_results:
        sweep_repetitions = sweep_result.repetitions
        # (key, qubit count) pairs describe how to unpack the measurement bits.
        key_sizes = [(m.key, len(m.qubits)) for m in sweep_result.measurement_keys]
        # Distinct loop name so the ``result`` parameter is not shadowed.
        for param_result in sweep_result.parameterized_results:
            measurements = v1.unpack_results(
                param_result.measurement_results, sweep_repetitions, key_sizes
            )
            trial_results.append(
                cirq.Result.from_single_parameter_set(
                    params=cirq.ParamResolver(param_result.params.assignments),
                    measurements=measurements,
                )
            )
    return trial_results
def _get_job_results_v2(result: v2.result_pb2.Result) -> List[cirq.Result]:
    """Converts a v2 result proto into a flat list of `cirq.Result`s.

    `v2.results_from_proto` groups results per sweep; the groups are
    concatenated into one list to match the sampler API.
    """
    flattened = []
    for sweep_result in v2.results_from_proto(result):
        flattened.extend(sweep_result)
    return flattened
def _raise_on_failure(job: quantum.types.QuantumJob) -> None:
    """Raises unless the job finished in the SUCCESS state.

    Args:
        job: The job proto whose execution status is inspected.

    Raises:
        RuntimeError: If the job is in the FAILURE state (message includes the
            processor and failure code), in any other terminal state, or in a
            non-terminal state (reported as a timeout).
    """
    status = job.execution_status
    state = status.state
    if state == quantum.enums.ExecutionStatus.State.SUCCESS:
        return

    job_name = job.name
    if state == quantum.enums.ExecutionStatus.State.FAILURE:
        processor = status.processor_name or 'UNKNOWN'
        code_name = quantum.types.ExecutionStatus.Failure.Code.Name(status.failure.error_code)
        raise RuntimeError(
            f'Job {job_name} on processor {processor} failed. '
            f'{code_name}: {status.failure.error_message}'
        )

    state_name = quantum.types.ExecutionStatus.State.Name(state)
    if state in TERMINAL_STATES:
        raise RuntimeError(f'Job {job_name} failed in state {state_name}.')
    # Not terminal: the caller stopped waiting before the job finished.
    raise RuntimeError(f'Timed out waiting for results. Job {job_name} is in state {state_name}')