1# Copyright 2019 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import datetime
16from unittest import mock
17import pytest
18
19from google.protobuf.text_format import Merge
20
21import cirq
22import cirq_google as cg
23from cirq_google.api import v1, v2
24from cirq_google.engine.client.quantum_v1alpha1 import types as qtypes
25from cirq_google.engine.engine import EngineContext
26
27
def _to_any(proto):
    """Pack ``proto`` into a newly created ``Any`` message and return it.

    Args:
        proto: The protobuf message to wrap.

    Returns:
        A fresh ``qtypes.any_pb2.Any`` whose payload is ``proto``.
    """
    packed = qtypes.any_pb2.Any()
    packed.Pack(proto)
    return packed
32
33
@pytest.fixture(scope='session', autouse=True)
def mock_grpc_client():
    """Replace the generated Quantum Engine gRPC client class with a mock.

    The fixture is autouse, so tests here run with ``QuantumEngineServiceClient``
    patched (by default ``mock.patch`` substitutes a ``MagicMock``). Presumably this
    keeps the bare ``EngineContext()`` instances below from reaching a real
    service. The patch is undone when the session-scoped fixture is torn down.
    """
    target = 'cirq_google.engine.engine_client.quantum.QuantumEngineServiceClient'
    with mock.patch(target) as client_cls_mock:
        yield client_cls_mock
40
41
def test_engine():
    """``engine()`` is bound to the job's project id."""
    engine = cg.EngineJob('a', 'b', 'steve', EngineContext()).engine()
    assert engine.project_id == 'a'
45
46
def test_program():
    """``program()`` refers back to the job's project and program ids."""
    job = cg.EngineJob('a', 'b', 'steve', EngineContext())
    ids = (job.program().project_id, job.program().program_id)
    assert ids == ('a', 'b')
51
52
def test_create_time():
    """``create_time()`` converts the cached proto timestamp into a datetime."""
    created = qtypes.timestamp_pb2.Timestamp(seconds=1581515101)
    cached_job = qtypes.QuantumJob(create_time=created)
    job = cg.EngineJob('a', 'b', 'steve', EngineContext(), _job=cached_job)

    expected = datetime.datetime(2020, 2, 12, 13, 45, 1)
    assert job.create_time() == expected
62
63
@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job')
def test_update_time(get_job):
    """``update_time()`` fetches the job through the client and converts its timestamp."""
    updated = qtypes.timestamp_pb2.Timestamp(seconds=1581515101)
    get_job.return_value = qtypes.QuantumJob(update_time=updated)
    job = cg.EngineJob('a', 'b', 'steve', EngineContext())

    assert job.update_time() == datetime.datetime(2020, 2, 12, 13, 45, 1)
    get_job.assert_called_once_with('a', 'b', 'steve', False)
72
73
@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job')
def test_description(get_job):
    """``description()`` uses a cached proto when present, otherwise fetches the job."""
    cached = cg.EngineJob(
        'a', 'b', 'steve', EngineContext(), _job=qtypes.QuantumJob(description='hello')
    )
    assert cached.description() == 'hello'

    get_job.return_value = qtypes.QuantumJob(description='hello')
    fetched = cg.EngineJob('a', 'b', 'steve', EngineContext())
    assert fetched.description() == 'hello'
    # Only the job built without a cached proto should have reached the client.
    get_job.assert_called_once_with('a', 'b', 'steve', False)
83
84
@mock.patch('cirq_google.engine.engine_client.EngineClient.set_job_description')
def test_set_description(set_job_description):
    """``set_description()`` forwards the text and exposes the client's reply."""
    job = cg.EngineJob('a', 'b', 'steve', EngineContext())
    # Cover both a normal description and clearing it with an empty string.
    for text in ('world', ''):
        set_job_description.return_value = qtypes.QuantumJob(description=text)
        assert job.set_description(text).description() == text
        set_job_description.assert_called_with('a', 'b', 'steve', text)
95
96
def test_labels():
    """``labels()`` reads the label map from the cached job proto."""
    cached_job = qtypes.QuantumJob(labels={'t': '1'})
    job = cg.EngineJob('a', 'b', 'steve', EngineContext(), _job=cached_job)
    assert job.labels() == {'t': '1'}
102
103
@mock.patch('cirq_google.engine.engine_client.EngineClient.set_job_labels')
def test_set_labels(set_job_labels):
    """``set_labels()`` forwards the full map (including an empty one) to the client."""
    job = cg.EngineJob('a', 'b', 'steve', EngineContext())
    cases = [
        ({'a': '1', 'b': '1'}, qtypes.QuantumJob(labels={'a': '1', 'b': '1'})),
        ({}, qtypes.QuantumJob()),
    ]
    for new_labels, reply in cases:
        set_job_labels.return_value = reply
        assert job.set_labels(new_labels).labels() == new_labels
        set_job_labels.assert_called_with('a', 'b', 'steve', new_labels)
114
115
@mock.patch('cirq_google.engine.engine_client.EngineClient.add_job_labels')
def test_add_labels(add_job_labels):
    """``add_labels()`` forwards the new labels and exposes the client's reply."""
    job = cg.EngineJob('a', 'b', 'steve', EngineContext(), _job=qtypes.QuantumJob(labels={}))
    assert job.labels() == {}

    steps = [
        ({'a': '1'}, {'a': '1'}),
        ({'a': '2', 'b': '1'}, {'a': '2', 'b': '1'}),
    ]
    for to_add, after in steps:
        add_job_labels.return_value = qtypes.QuantumJob(labels=after)
        assert job.add_labels(to_add).labels() == after
        add_job_labels.assert_called_with('a', 'b', 'steve', to_add)
132
133
@mock.patch('cirq_google.engine.engine_client.EngineClient.remove_job_labels')
def test_remove_labels(remove_job_labels):
    """``remove_labels()`` forwards the keys to drop and exposes the client's reply."""
    initial = qtypes.QuantumJob(labels={'a': '1', 'b': '1'})
    job = cg.EngineJob('a', 'b', 'steve', EngineContext(), _job=initial)
    assert job.labels() == {'a': '1', 'b': '1'}

    # The second step also asks to remove a key ('c') that was never present.
    steps = [
        (['a'], {'b': '1'}),
        (['a', 'b', 'c'], {}),
    ]
    for keys, remaining in steps:
        remove_job_labels.return_value = qtypes.QuantumJob(labels=remaining)
        assert job.remove_labels(keys).labels() == remaining
        remove_job_labels.assert_called_with('a', 'b', 'steve', keys)
152
153
def test_processor_ids():
    """``processor_ids()`` maps scheduling processor resource names to bare ids."""
    selector = qtypes.SchedulingConfig.ProcessorSelector(
        processor_names=['projects/a/processors/p']
    )
    cached_job = qtypes.QuantumJob(
        scheduling_config=qtypes.SchedulingConfig(processor_selector=selector)
    )
    job = cg.EngineJob('a', 'b', 'steve', EngineContext(), _job=cached_job)
    assert job.processor_ids() == ['p']
169
170
@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job')
def test_status(get_job):
    """``status()`` fetches the job and reports its execution state by name."""
    get_job.return_value = qtypes.QuantumJob(
        execution_status=qtypes.ExecutionStatus(state=qtypes.ExecutionStatus.State.RUNNING)
    )

    job = cg.EngineJob('a', 'b', 'steve', EngineContext())
    assert job.status() == 'RUNNING'
    # Pin the exact request, as test_update_time and test_description do, rather than
    # only counting calls, so a regression in the forwarded ids or flag is caught.
    get_job.assert_called_once_with('a', 'b', 'steve', False)
181
182
def test_failure():
    """``failure()`` returns the (error code name, message) pair of a failed job."""
    failure = qtypes.ExecutionStatus.Failure(
        error_code=qtypes.ExecutionStatus.Failure.Code.SYSTEM_ERROR,
        error_message='boom',
    )
    status = qtypes.ExecutionStatus(
        state=qtypes.ExecutionStatus.State.FAILURE,
        failure=failure,
    )
    job = cg.EngineJob(
        'a', 'b', 'steve', EngineContext(), _job=qtypes.QuantumJob(execution_status=status)
    )
    assert job.failure() == ('SYSTEM_ERROR', 'boom')
200
201
def test_failure_with_no_error():
    """A job in the SUCCESS state reports no failure (falsy ``failure()``)."""
    succeeded = qtypes.ExecutionStatus(state=qtypes.ExecutionStatus.State.SUCCESS)
    job = cg.EngineJob(
        'a', 'b', 'steve', EngineContext(), _job=qtypes.QuantumJob(execution_status=succeeded)
    )
    assert not job.failure()
215
216
@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job')
def test_get_repetitions_and_sweeps(get_job):
    """A v2 RunContext with a sweep lacking parameters decodes to ``cirq.UnitSweep``."""
    run_context = v2.run_context_pb2.RunContext(
        parameter_sweeps=[v2.run_context_pb2.ParameterSweep(repetitions=10)]
    )
    job = cg.EngineJob('a', 'b', 'steve', EngineContext())
    get_job.return_value = qtypes.QuantumJob(run_context=_to_any(run_context))

    assert job.get_repetitions_and_sweeps() == (10, [cirq.UnitSweep])
    # Unlike the status/description lookups, this fetch passes True as the final flag
    # (presumably asking the client to include the run context).
    get_job.assert_called_once_with('a', 'b', 'steve', True)
229
230
@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job')
def test_get_repetitions_and_sweeps_v1(get_job):
    """Legacy v1 RunContext payloads are rejected with a ValueError."""
    legacy_context = v1.program_pb2.RunContext(
        parameter_sweeps=[v1.params_pb2.ParameterSweep(repetitions=10)]
    )
    job = cg.EngineJob('a', 'b', 'steve', EngineContext())
    get_job.return_value = qtypes.QuantumJob(run_context=_to_any(legacy_context))

    with pytest.raises(ValueError, match='v1 RunContext is not supported'):
        job.get_repetitions_and_sweeps()
243
244
@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job')
def test_get_repetitions_and_sweeps_unsupported(get_job):
    """An ``Any`` with an unrecognised type URL is reported as unsupported."""
    job = cg.EngineJob('a', 'b', 'steve', EngineContext())
    unknown = qtypes.any_pb2.Any(type_url='type.googleapis.com/unknown.proto')
    get_job.return_value = qtypes.QuantumJob(run_context=unknown)

    with pytest.raises(ValueError, match='unsupported run_context type: unknown.proto'):
        job.get_repetitions_and_sweeps()
253
254
def test_get_processor():
    """``get_processor()`` parses the processor id out of the execution status."""
    status = qtypes.ExecutionStatus(processor_name='projects/a/processors/p')
    job = cg.EngineJob(
        'a', 'b', 'steve', EngineContext(), _job=qtypes.QuantumJob(execution_status=status)
    )
    processor = job.get_processor()
    assert processor.processor_id == 'p'
262
263
def test_get_processor_no_processor():
    """An execution status without a processor name yields no processor."""
    job = cg.EngineJob(
        'a',
        'b',
        'steve',
        EngineContext(),
        _job=qtypes.QuantumJob(execution_status=qtypes.ExecutionStatus()),
    )
    assert not job.get_processor()
269
270
@mock.patch('cirq_google.engine.engine_client.EngineClient.get_calibration')
def test_get_calibration(get_calibration):
    """``get_calibration()`` resolves the job's calibration name and loads its metrics."""
    snapshot = Merge(
        """
    timestamp_ms: 123000,
    metrics: [{
        name: 'xeb',
        targets: ['0_0', '0_1'],
        values: [{
            double_val: .9999
        }]
    }, {
        name: 't1',
        targets: ['0_0'],
        values: [{
            double_val: 321
        }]
    }, {
        name: 'globalMetric',
        values: [{
            int32_val: 12300
        }]
    }]
""",
        v2.metrics_pb2.MetricsSnapshot(),
    )
    get_calibration.return_value = qtypes.QuantumCalibration(data=_to_any(snapshot))

    status = qtypes.ExecutionStatus(
        calibration_name='projects/a/processors/p/calibrations/123'
    )
    job = cg.EngineJob(
        'a', 'b', 'steve', EngineContext(), _job=qtypes.QuantumJob(execution_status=status)
    )
    # Iterating the returned calibration yields metric names in snapshot order.
    assert list(job.get_calibration()) == ['xeb', 't1', 'globalMetric']
    # The resource name is forwarded as (project id, processor id, numeric calibration id).
    get_calibration.assert_called_once_with('a', 'p', 123)
311
312
@mock.patch('cirq_google.engine.engine_client.EngineClient.get_calibration')
def test_calibration__with_no_calibration(get_calibration):
    """Without a calibration name on the job, nothing is fetched and the result is falsy."""
    cached_job = qtypes.QuantumJob(
        name='projects/project-id/programs/test/jobs/test',
        execution_status={'state': 'SUCCESS'},
    )
    job = cg.EngineJob('a', 'b', 'steve', EngineContext(), _job=cached_job)

    assert not job.get_calibration()
    get_calibration.assert_not_called()
328
329
@mock.patch('cirq_google.engine.engine_client.EngineClient.cancel_job')
def test_cancel(cancel_job):
    """``cancel()`` forwards the job's identifying ids to the client."""
    cg.EngineJob('a', 'b', 'steve', EngineContext()).cancel()
    cancel_job.assert_called_once_with('a', 'b', 'steve')
335
336
@mock.patch('cirq_google.engine.engine_client.EngineClient.delete_job')
def test_delete(delete_job):
    """``delete()`` forwards the job's identifying ids to the client."""
    cg.EngineJob('a', 'b', 'steve', EngineContext()).delete()
    delete_job.assert_called_once_with('a', 'b', 'steve')
342
343
# Canned single-sweep v2 ``Result`` returned by the mocked ``get_job_results``.
# One sweep of 4 repetitions over two parameter points (a=1 and a=2). Each point
# holds packed measurement bytes for key 'q' on qubit 1_1; the tests below expect
# them to render as 'q=0110' and 'q=1010' respectively.
RESULTS = qtypes.QuantumResult(
    result=_to_any(
        Merge(
            """
sweep_results: [{
        repetitions: 4,
        parameterized_results: [{
            params: {
                assignments: {
                    key: 'a'
                    value: 1
                }
            },
            measurement_results: {
                key: 'q'
                qubit_measurement_results: [{
                  qubit: {
                    id: '1_1'
                  }
                  results: '\006'
                }]
            }
        },{
            params: {
                assignments: {
                    key: 'a'
                    value: 2
                }
            },
            measurement_results: {
                key: 'q'
                qubit_measurement_results: [{
                  qubit: {
                    id: '1_1'
                  }
                  results: '\005'
                }]
            }
        }]
    }]
""",
            v2.result_pb2.Result(),
        )
    )
)
389
390
# Canned v2 ``BatchResult`` holding two single-sweep results: the first has
# 3 repetitions over a=1 and a=2, the second 4 repetitions over a=3 and a=4. All of
# them measure key 'q' on qubit 1_1. test_batched_results expects the flattened
# order 'q=011', 'q=111', 'q=1101', 'q=1001'.
BATCH_RESULTS = qtypes.QuantumResult(
    result=_to_any(
        Merge(
            """
results: [{
    sweep_results: [{
        repetitions: 3,
        parameterized_results: [{
            params: {
                assignments: {
                    key: 'a'
                    value: 1
                }
            },
            measurement_results: {
                key: 'q'
                qubit_measurement_results: [{
                  qubit: {
                    id: '1_1'
                  }
                  results: '\006'
                }]
            }
        },{
            params: {
                assignments: {
                    key: 'a'
                    value: 2
                }
            },
            measurement_results: {
                key: 'q'
                qubit_measurement_results: [{
                  qubit: {
                    id: '1_1'
                  }
                  results: '\007'
                }]
            }
        }]
    }],
    },{
    sweep_results: [{
        repetitions: 4,
        parameterized_results: [{
            params: {
                assignments: {
                    key: 'a'
                    value: 3
                }
            },
            measurement_results: {
                key: 'q'
                qubit_measurement_results: [{
                  qubit: {
                    id: '1_1'
                  }
                  results: '\013'
                }]
            }
        },{
            params: {
                assignments: {
                    key: 'a'
                    value: 4
                }
            },
            measurement_results: {
                key: 'q'
                qubit_measurement_results: [{
                  qubit: {
                    id: '1_1'
                  }
                  results: '\011'
                }]
            }
        }]
    }]
}]
""",
            v2.batch_pb2.BatchResult(),
        )
    )
)
475
# Canned v2 ``FocusedCalibrationResult`` with a single failed entry
# (ERROR_CALIBRATION_FAILED, message 'uh oh', token 'abc'). Its metrics snapshot
# carries one 'theta' value on the qubit pair 0_0/0_1. valid_until_ms is
# 1234567891000; test_calibration_results expects it to surface as a
# ``valid_until`` timestamp of 1234567891 seconds.
CALIBRATION_RESULT = qtypes.QuantumResult(
    result=_to_any(
        Merge(
            """
results: [{
    code: ERROR_CALIBRATION_FAILED
    error_message: 'uh oh'
    token: 'abc'
    valid_until_ms: 1234567891000
    metrics: {
        timestamp_ms: 1234567890000,
        metrics: [{
            name: 'theta',
            targets: ['0_0', '0_1'],
            values: [{
                double_val: .9999
            }]
        }]
    }
}]
""",
            v2.calibration_pb2.FocusedCalibrationResult(),
        )
    )
)
501
502
@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results')
def test_results(get_job_results):
    """``results()`` yields one entry per parameter point of a single-sweep result."""
    get_job_results.return_value = RESULTS
    succeeded = qtypes.QuantumJob(
        execution_status=qtypes.ExecutionStatus(state=qtypes.ExecutionStatus.State.SUCCESS)
    )
    job = cg.EngineJob('a', 'b', 'steve', EngineContext(), _job=succeeded)

    rendered = [str(result) for result in job.results()]
    assert rendered == ['q=0110', 'q=1010']
    get_job_results.assert_called_once_with('a', 'b', 'steve')
516
517
@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results')
def test_results_iter(get_job_results):
    """Iterating a job walks over its results in order."""
    get_job_results.return_value = RESULTS
    succeeded = qtypes.QuantumJob(
        execution_status=qtypes.ExecutionStatus(state=qtypes.ExecutionStatus.State.SUCCESS)
    )
    job = cg.EngineJob('a', 'b', 'steve', EngineContext(), _job=succeeded)

    assert list(map(str, job)) == ['q=0110', 'q=1010']
530
531
@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results')
def test_results_getitem(get_job_results):
    """Indexing a job returns individual results and raises past the end."""
    get_job_results.return_value = RESULTS
    succeeded = qtypes.QuantumJob(
        execution_status=qtypes.ExecutionStatus(state=qtypes.ExecutionStatus.State.SUCCESS)
    )
    job = cg.EngineJob('a', 'b', 'steve', EngineContext(), _job=succeeded)

    for index, expected in enumerate(['q=0110', 'q=1010']):
        assert str(job[index]) == expected
    with pytest.raises(IndexError):
        _ = job[2]
544
545
@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results')
def test_batched_results(get_job_results):
    """Batch results are flattened by ``results()`` and grouped by ``batched_results()``."""
    get_job_results.return_value = BATCH_RESULTS
    succeeded = qtypes.QuantumJob(
        execution_status=qtypes.ExecutionStatus(state=qtypes.ExecutionStatus.State.SUCCESS)
    )
    job = cg.EngineJob('a', 'b', 'steve', EngineContext(), _job=succeeded)

    flat = [str(result) for result in job.results()]
    assert flat == ['q=011', 'q=111', 'q=1101', 'q=1001']
    get_job_results.assert_called_once_with('a', 'b', 'steve')

    grouped = [[str(result) for result in batch] for batch in job.batched_results()]
    assert grouped == [
        ['q=011', 'q=111'],
        ['q=1101', 'q=1001'],
    ]
570
571
@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results')
def test_batched_results_not_a_batch(get_job_results):
    """``batched_results()`` rejects a plain (non-batch) result payload."""
    get_job_results.return_value = RESULTS
    succeeded = qtypes.QuantumJob(
        execution_status=qtypes.ExecutionStatus(state=qtypes.ExecutionStatus.State.SUCCESS)
    )
    job = cg.EngineJob('a', 'b', 'steve', EngineContext(), _job=succeeded)

    with pytest.raises(ValueError, match='batched_results'):
        job.batched_results()
581
582
@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results')
def test_calibration_results(get_job_results):
    """``calibration_results()`` decodes every field of a focused calibration result."""
    get_job_results.return_value = CALIBRATION_RESULT
    succeeded = qtypes.QuantumJob(
        execution_status=qtypes.ExecutionStatus(state=qtypes.ExecutionStatus.State.SUCCESS)
    )
    job = cg.EngineJob('a', 'b', 'steve', EngineContext(), _job=succeeded)

    data = job.calibration_results()
    get_job_results.assert_called_once_with('a', 'b', 'steve')
    assert len(data) == 1

    first = data[0]
    assert first.code == v2.calibration_pb2.ERROR_CALIBRATION_FAILED
    assert first.error_message == 'uh oh'
    assert first.token == 'abc'
    assert first.valid_until.timestamp() == 1234567891
    assert len(first.metrics)
    qubit_pair = (cirq.GridQubit(0, 0), cirq.GridQubit(0, 1))
    assert first.metrics['theta'] == {qubit_pair: [0.9999]}
599
600
@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results')
def test_calibration_defaults(get_job_results):
    """An all-default calibration entry maps unset fields to None and empty metrics."""
    empty_result = v2.calibration_pb2.FocusedCalibrationResult()
    empty_result.results.add()
    get_job_results.return_value = qtypes.QuantumResult(result=_to_any(empty_result))
    succeeded = qtypes.QuantumJob(
        execution_status=qtypes.ExecutionStatus(state=qtypes.ExecutionStatus.State.SUCCESS)
    )
    job = cg.EngineJob('a', 'b', 'steve', EngineContext(), _job=succeeded)

    data = job.calibration_results()
    get_job_results.assert_called_once_with('a', 'b', 'steve')
    assert len(data) == 1

    only = data[0]
    assert only.code == v2.calibration_pb2.CALIBRATION_RESULT_UNSPECIFIED
    assert only.error_message is None
    assert only.token is None
    assert only.valid_until is None
    assert len(only.metrics) == 0
618
619
@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results')
def test_calibration_results_not_a_calibration(get_job_results):
    """``calibration_results()`` rejects an ordinary sweep result payload."""
    get_job_results.return_value = RESULTS
    succeeded = qtypes.QuantumJob(
        execution_status=qtypes.ExecutionStatus(state=qtypes.ExecutionStatus.State.SUCCESS)
    )
    job = cg.EngineJob('a', 'b', 'steve', EngineContext(), _job=succeeded)

    with pytest.raises(ValueError, match='calibration results'):
        job.calibration_results()
629
630
@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results')
def test_results_len(get_job_results):
    """``len(job)`` counts the job's results."""
    get_job_results.return_value = RESULTS
    succeeded = qtypes.QuantumJob(
        execution_status=qtypes.ExecutionStatus(state=qtypes.ExecutionStatus.State.SUCCESS)
    )
    assert len(cg.EngineJob('a', 'b', 'steve', EngineContext(), _job=succeeded)) == 2
640
641
@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job')
@mock.patch('time.sleep', return_value=None)
def test_timeout(patched_time_sleep, get_job):
    """``results()`` raises 'Timed out' for a job that stays RUNNING past the context timeout.

    ``time.sleep`` is replaced by a mock, so any waiting costs no wall-clock time.
    """
    running = qtypes.ExecutionStatus(state=qtypes.ExecutionStatus.State.RUNNING)
    get_job.return_value = qtypes.QuantumJob(execution_status=running)
    job = cg.EngineJob('a', 'b', 'steve', EngineContext(timeout=500))

    with pytest.raises(RuntimeError, match='Timed out'):
        job.results()
652
653
def test_str():
    """``str(job)`` renders the project, program and job ids."""
    expected = 'EngineJob(project_id=\'a\', program_id=\'b\', job_id=\'steve\')'
    assert str(cg.EngineJob('a', 'b', 'steve', EngineContext())) == expected
657