1# coding: utf-8
2
3#-------------------------------------------------------------------------
4# Copyright (c) Microsoft Corporation. All rights reserved.
5# Licensed under the MIT License. See License.txt in the project root for
6# license information.
7#--------------------------------------------------------------------------
8import unittest
9
10import azure.mgmt.sql
11
12from devtools_testutils import (
13    AzureMgmtTestCase, ResourceGroupPreparer,
14    AzureMgmtPreparer, FakeResource
15)
16
17
def get_server_params(location):
    """Return the request body used to create the test SQL server.

    :param location: Accepted for call-site compatibility but not used;
        the server location is always ``'westus2'``.
    :returns: A new dict with ``location``, ``version``,
        ``administrator_login`` and ``administrator_login_password`` keys.
    """
    # NOTE(review): ``location`` is intentionally ignored. An earlier comment
    # noted the framework's default region ("self.region") is 'west-us';
    # presumably westus2 is pinned so recorded sessions stay stable -- confirm
    # before switching to the caller-supplied value.
    params = dict(
        location='westus2',
        version='12.0',
    )
    # Fixed test credentials; only meaningful for live recording runs.
    params['administrator_login'] = 'mysecretname'
    params['administrator_login_password'] = 'HusH_Sec4et'
    return params
25
26
class SqlServerPreparer(AzureMgmtPreparer):
    """Preparer that supplies an Azure SQL server to a decorated test.

    In live mode a real server is created through the test instance's
    ``client``. In playback mode a ``FakeResource`` stands in for it; that
    resource carries only the generated name and an empty id.

    NOTE(review): ``create_resource`` reads ``resource_group`` and
    ``location`` from its kwargs. Presumably a ``ResourceGroupPreparer``
    applied outside this one supplies them -- confirm the decorator order
    at each use site.
    """

    def __init__(self, name_prefix='mypysqlserverx'):
        # 24 is passed as AzureMgmtPreparer's second positional argument;
        # presumably the length of the generated random name -- confirm.
        super(SqlServerPreparer, self).__init__(name_prefix, 24)

    def create_resource(self, name, **kwargs):
        """Create (live) or fake (playback) the SQL server called ``name``.

        :param name: Resource name handed in by the preparer framework.
        :param kwargs: Must contain ``resource_group`` (an object with a
            ``name`` attribute) and ``location``.
        :returns: ``{'server': <server object>}``.
        """
        if not self.is_live:
            # Playback: no service call, just a named placeholder.
            return {'server': FakeResource(name=name, id='')}

        servers = self.test_class_instance.client.servers
        poller = servers.create_or_update(
            kwargs['resource_group'].name,
            name,
            get_server_params(kwargs['location']),
        )
        # Block until provisioning completes and return the created server.
        return {'server': poller.result()}
45
46
class MgmtSqlTest(AzureMgmtTestCase):
    """Management-plane tests for SQL servers, firewall rules and databases."""

    def setUp(self):
        """Create the ``SqlManagementClient`` used by every test method."""
        super(MgmtSqlTest, self).setUp()
        self.client = self.create_mgmt_client(
            azure.mgmt.sql.SqlManagementClient
        )

    @ResourceGroupPreparer()
    def test_server(self, resource_group, location):
        """Exercise create/get/list/usages/firewall-rule/delete on a server.

        :param resource_group: Resource group supplied by ResourceGroupPreparer.
        :param location: Supplied by ResourceGroupPreparer and forwarded to
            ``get_server_params``.
        """
        server_name = self.get_resource_name('tstpysqlserverx')

        # create_or_update returns a poller; result() waits for completion.
        server = self.client.servers.create_or_update(
            resource_group.name,  # Created by the framework
            server_name,
            get_server_params(location),
        ).result()
        self.assertEqual(server.name, server_name)

        server = self.client.servers.get(resource_group.name, server_name)
        self.assertEqual(server.name, server_name)

        # The resource group was created for this test, so ours is its only server.
        rg_servers = list(self.client.servers.list_by_resource_group(resource_group.name))
        self.assertEqual(len(rg_servers), 1)
        self.assertEqual(rg_servers[0].name, server_name)

        # The subscription-wide listing may contain unrelated servers; only
        # require that ours is among them.
        all_server_names = [s.name for s in self.client.servers.list()]
        self.assertIn(server_name, all_server_names)

        usage_names = [
            usage.name
            for usage in self.client.server_usages.list_by_server(resource_group.name, server_name)
        ]
        self.assertIn('server_dtu_quota', usage_names)

        firewall_rule_name = self.get_resource_name('firewallrule')
        firewall_rule = self.client.firewall_rules.create_or_update(
            resource_group.name,
            server_name,
            firewall_rule_name,
            "123.123.123.123",
            "123.123.123.124"
        )
        self.assertEqual(firewall_rule.name, firewall_rule_name)
        self.assertEqual(firewall_rule.start_ip_address, "123.123.123.123")
        self.assertEqual(firewall_rule.end_ip_address, "123.123.123.124")

        # polling=False: the returned poller is not waited on here.
        self.client.servers.delete(resource_group.name, server_name, polling=False)

    @ResourceGroupPreparer()
    @SqlServerPreparer()
    def test_database(self, resource_group, location, server):
        """Exercise create/get/list/usages/delete on a database.

        :param resource_group: Resource group supplied by ResourceGroupPreparer.
        :param location: Supplied by ResourceGroupPreparer (not used; the
            database location is pinned below).
        :param server: Server supplied by SqlServerPreparer.
        """
        db_name = self.get_resource_name('pyarmdb')

        database = self.client.databases.create_or_update(
            resource_group.name,
            server.name,
            db_name,
            {
                # Pinned to the same region get_server_params uses for the server.
                'location': 'westus2'
            }
        ).result()  # Wait for completion and return created object
        self.assertEqual(database.name, db_name)

        db = self.client.databases.get(resource_group.name, server.name, db_name)
        self.assertEqual(db.name, db_name)

        # Expect exactly the server's 'master' database plus the new one.
        # The names are passed as the failure message in place of a debug print.
        db_names = [
            d.name
            for d in self.client.databases.list_by_server(resource_group.name, server.name)
        ]
        self.assertEqual(len(db_names), 2, db_names)
        self.assertIn('master', db_names)
        self.assertIn(db_name, db_names)

        usage_names = [
            usage.name
            for usage in self.client.database_usages.list_by_database(
                resource_group.name, server.name, db_name
            )
        ]
        self.assertIn('database_size', usage_names)

        self.client.databases.delete(resource_group.name, server.name, db_name).wait()
131
132
133#------------------------------------------------------------------------------
if __name__ == '__main__':
    # Executed directly (not imported): run this module's test cases with the
    # standard-library runner. '__main__' is unittest.main's default module.
    unittest.main(module='__main__')
136