1# --------------------------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See License.txt in the project root for license information.
4# --------------------------------------------------------------------------------------------
5
6from knack.log import get_logger
7from knack.util import CLIError
8from azure.cli.core.commands import LongRunningOperation
9
10from ._constants import ACR_TASK_YAML_DEFAULT_NAME
11from ._stream_utils import stream_logs
12from ._utils import (
13    validate_managed_registry,
14    get_validate_platform,
15    get_custom_registry_credentials,
16    get_yaml_template,
17    prepare_source_location
18)
19from ._client_factory import cf_acr_registries_tasks
20
logger = get_logger(__name__)

# Message handed to validate_managed_registry by acr_run; presumably shown when
# the target registry is not a managed registry (TODO confirm in _utils).
RUN_NOT_SUPPORTED = 'Run is only available for managed registries.'
24
25
def acr_run(cmd,  # pylint: disable=too-many-locals
            client,
            registry_name,
            source_location,
            agent_pool_name=None,
            file=None,
            values=None,
            set_value=None,
            set_secret=None,
            cmd_value=None,
            no_format=False,
            no_logs=False,
            no_wait=False,
            timeout=None,
            resource_group_name=None,
            platform=None,
            auth_mode=None,
            log_template=None):
    """Queue a quick run on a managed container registry and optionally follow it.

    With a source context (after ``prepare_source_location``), the run is
    described by a task file inside that context (``FileTaskRunRequest``).
    Without one, a YAML template is built from ``cmd_value``/``file`` and sent
    base64-encoded (``EncodedTaskRunRequest``).

    :param cmd: The CLI command context.
    :param client: Runs client used to poll or stream logs of the queued run.
    :param registry_name: Name of the target container registry.
    :param source_location: Source context for the run. ``prepare_source_location``
        may resolve it to a falsy value (presumably for ``/dev/null`` -- TODO confirm).
    :param agent_pool_name: Optional agent pool to run on.
    :param file: Task file path. Defaults to ``ACR_TASK_YAML_DEFAULT_NAME`` when a
        source context is present.
    :param values: Values file path. Only applied when a source context is present;
        otherwise a warning is logged and it is ignored.
    :param set_value: List of values for the run.
    :param set_secret: List of secret values for the run (concatenated after ``set_value``).
    :param cmd_value: A single command to run instead of a task file.
    :param no_format: Forwarded to ``stream_logs``.
    :param no_logs: Poll for the final run instead of streaming its logs.
    :param no_wait: Return immediately after the run is queued.
    :param timeout: Run timeout, forwarded to the request and to log streaming.
    :param resource_group_name: Registry resource group. The value returned by
        ``validate_managed_registry`` is used from then on.
    :param platform: Platform string checked by ``get_validate_platform``.
    :param auth_mode: Forwarded to ``get_custom_registry_credentials``.
    :param log_template: Log template forwarded to the run request.
    :return: The queued run when ``no_wait``; otherwise the result of
        ``get_run_with_polling`` (``no_logs``) or ``stream_logs``.
    :raises CLIError: If both ``cmd_value`` and ``file`` are given.
    """
    # Reject mutually exclusive arguments before any service call, so the user
    # gets the error without waiting on a registry lookup.
    if cmd_value and file:
        raise CLIError(
            "Azure Container Registry can run with either "
            "--cmd myCommand /dev/null or "
            "-f myFile mySourceLocation, but not both.")

    _, resource_group_name = validate_managed_registry(
        cmd, registry_name, resource_group_name, RUN_NOT_SUPPORTED)

    client_registries = cf_acr_registries_tasks(cmd.cli_ctx)
    source_location = prepare_source_location(
        cmd, source_location, client_registries, registry_name, resource_group_name)

    platform_os, platform_arch, platform_variant = get_validate_platform(cmd, platform)

    EncodedTaskRunRequest, FileTaskRunRequest, PlatformProperties = cmd.get_models(
        'EncodedTaskRunRequest', 'FileTaskRunRequest', 'PlatformProperties', operation_group='runs')

    # Choose the request model and its model-specific arguments first. This keeps
    # get_yaml_template ahead of the credential lookup, as before.
    if source_location:
        request_model = FileTaskRunRequest
        specific_args = {
            'task_file_path': file if file else ACR_TASK_YAML_DEFAULT_NAME,
            'values_file_path': values,
        }
    else:
        # The encoded request has no values-file field here, so say so rather
        # than silently dropping the user's input.
        if values:
            logger.warning(
                "Ignoring values file '%s': it is only applied when a source context is provided.",
                values)
        yaml_template = get_yaml_template(cmd_value, timeout, file)
        import base64
        request_model = EncodedTaskRunRequest
        specific_args = {
            'encoded_task_content': base64.b64encode(yaml_template.encode()).decode(),
        }

    # Arguments shared by both request models.
    request = request_model(
        values=(set_value or []) + (set_secret or []),
        source_location=source_location,
        timeout=timeout,
        platform=PlatformProperties(
            os=platform_os,
            architecture=platform_arch,
            variant=platform_variant
        ),
        credentials=get_custom_registry_credentials(
            cmd=cmd,
            auth_mode=auth_mode
        ),
        agent_pool_name=agent_pool_name,
        log_template=log_template,
        **specific_args
    )

    queued = LongRunningOperation(cmd.cli_ctx)(client_registries.begin_schedule_run(
        resource_group_name=resource_group_name,
        registry_name=registry_name,
        run_request=request))

    run_id = queued.run_id
    logger.warning("Queued a run with ID: %s", run_id)

    if no_wait:
        return queued

    logger.warning("Waiting for an agent...")

    if no_logs:
        from ._run_polling import get_run_with_polling
        return get_run_with_polling(cmd, client, run_id, registry_name, resource_group_name)

    return stream_logs(cmd, client, run_id, registry_name, resource_group_name, timeout, no_format, True)
121