1from decimal import Decimal
2import os
3import random
4import warnings
5
6from sqlalchemy import __version__
7from sqlalchemy import Column
8from sqlalchemy import create_engine
9from sqlalchemy import ForeignKey
10from sqlalchemy import Integer
11from sqlalchemy import Numeric
12from sqlalchemy import String
13from sqlalchemy.ext.declarative import declarative_base
14from sqlalchemy.orm import relationship
15from sqlalchemy.orm import Session
16
# Suppress warnings whose message matches "Decimal objects natively"
# (presumably the SQLite dialect's notice about Decimal support for the
# Numeric columns used below -- TODO confirm).
warnings.filterwarnings("ignore", r".*Decimal objects natively")  # noqa

# Opportunistically register the third-party C-accelerated ``cdecimal``
# package as the ``decimal`` module when it is installed.
# NOTE(review): this only affects ``import decimal`` statements executed
# *after* this point; ``Decimal`` was already bound from the stdlib module by
# the first import line, and SQLAlchemy was imported before this as well.
try:
    import cdecimal
except ImportError:
    pass
else:
    import sys

    sys.modules["decimal"] = cdecimal

# Declarative base shared by every mapped class in this benchmark.
Base = declarative_base()
29
30
class Employee(Base):
    """Root of the polymorphic hierarchy, mapped to the ``employee`` table.

    Subclasses (``Boss``, ``Grunt``) keep their own tables whose primary key
    is also a foreign key to ``employee.id``. The ``type`` column stores each
    row's ``polymorphic_identity``.
    """

    __tablename__ = "employee"

    id = Column(Integer, primary_key=True)
    name = Column(String(100), nullable=False)
    # Discriminator column, referenced by ``polymorphic_on`` below. Within
    # this class body the name shadows the builtin ``type``.
    type = Column(String(50), nullable=False)

    __mapper_args__ = {"polymorphic_on": type}
39
40
class Boss(Employee):
    """Employee subtype stored in the ``boss`` table (identity ``"boss"``).

    Rows share their primary key with ``employee``. The ``employees``
    collection on this class is created by the backref declared on
    ``Grunt.employer``.
    """

    __tablename__ = "boss"

    # Primary key doubles as the link to the parent ``employee`` row.
    id = Column(Integer, ForeignKey("employee.id"), primary_key=True)
    golf_average = Column(Numeric)

    __mapper_args__ = {"polymorphic_identity": "boss"}
48
49
class Grunt(Employee):
    """Employee subtype stored in the ``grunt`` table (identity ``"grunt"``).

    Each grunt may point at one :class:`Boss` through ``employer_id``.
    """

    __tablename__ = "grunt"

    # Primary key doubles as the link to the parent ``employee`` row.
    id = Column(Integer, ForeignKey("employee.id"), primary_key=True)
    savings = Column(Numeric)

    employer_id = Column(Integer, ForeignKey("boss.id"))

    # Many-to-one to Boss; the backref adds a ``Boss.employees`` collection.
    # The explicit primaryjoin selects the employer_id -> boss.id path. It is
    # likely required because ``grunt.id`` also references ``employee``, which
    # is part of what Boss is mapped to -- TODO confirm.
    employer = relationship(
        "Boss", backref="employees", primaryjoin=Boss.id == employer_id
    )

    __mapper_args__ = {"polymorphic_identity": "grunt"}
63
64
# A file-based database (rather than an in-memory one) is used so that each
# cursor.execute() call carries some palpable overhead. Any database file
# left over from a previous run is removed first.
if os.path.exists("orm2010.db"):
    os.remove("orm2010.db")

engine = create_engine("sqlite:///orm2010.db")
Base.metadata.create_all(bind=engine)

# Single module-level session shared by the benchmark functions below.
sess = Session(bind=engine)
74
75
def runit(status, factor=1, query_runs=5):
    """Run one pass of the ORM benchmark using the module-level ``sess``.

    Creates ``100 * factor`` Boss objects and 100 Grunts per Boss. Each
    consecutive batch of 100 Grunts is assigned to one Boss, and the result is
    committed. Then, ``query_runs`` times, every Grunt is loaded together
    with its employer's columns into an in-memory report.

    :param status: callable taking a single progress-message string.
    :param factor: scale multiplier for the number of objects created.
    :param query_runs: number of times the read-heavy report loop runs.
    """
    num_bosses = 100 * factor
    num_grunts = num_bosses * 100

    bosses = [
        Boss(name="Boss %d" % i, golf_average=Decimal(random.randint(40, 150)))
        for i in range(num_bosses)
    ]

    sess.add_all(bosses)
    status("Added %d boss objects" % num_bosses)

    grunts = [
        Grunt(
            name="Grunt %d" % i,
            # Build the Decimal from the integer cent amount and divide in
            # Decimal arithmetic. Decimal(int / 100) would round-trip through
            # a binary float and store values like 123456.779999...
            savings=Decimal(random.randint(5000000, 15000000)) / 100,
        )
        for i in range(num_grunts)
    ]
    # The grunts are not add()-ed explicitly. They presumably reach the
    # session through the ``employees`` backref when their employer is set.
    status("Added %d grunt objects" % num_grunts)

    # Assign consecutive batches of ``batch_size`` grunts to "Boss 0",
    # "Boss 1", ... With 100 grunts per boss, every boss gets exactly one
    # batch. The per-batch query is deliberately part of the workload.
    batch_size = 100
    for batch_num, start in enumerate(range(0, num_grunts, batch_size)):
        boss = sess.query(Boss).filter_by(name="Boss %d" % batch_num).first()
        for grunt in grunts[start:start + batch_size]:
            grunt.employer = boss
    # Drop the list so this function holds no references to the Grunt
    # objects during the query runs below.
    del grunts

    sess.commit()
    status("Associated grunts w/ bosses and committed")

    # do some heavier reading
    for i in range(query_runs):
        status("Heavy query run #%d" % (i + 1))

        # Built only to exercise attribute loading; it is never printed.
        report = []

        # Load all the Grunts and collect their name and savings plus their
        # boss's name and golf average.
        for grunt in sess.query(Grunt):
            report.append(
                (
                    grunt.name,
                    grunt.savings,
                    grunt.employer.name,
                    grunt.employer.golf_average,
                )
            )

        # Close the session after every run, presumably so each run starts
        # from an empty session and reloads everything.
        sess.close()
131
132
def run_with_profile(runsnake=False, dump=False):
    """Run the default-size benchmark under cProfile and print a summary.

    The raw profile is written to ``orm2010.profile``, replacing any previous
    one. The summary prints:

    * the SQLAlchemy version,
    * total function calls and ``total_tt``,
    * the number of ``sqlite3.Cursor.execute`` and ``executemany`` calls.

    :param runsnake: if true, launch ``runsnake`` on the profile afterwards.
    :param dump: if true, also print the full profile sorted by time.
    """
    import cProfile
    import pstats

    filename = "orm2010.profile"

    if os.path.exists(filename):
        os.remove(filename)

    def status(msg):
        print(msg)

    # runctx() evaluates the statement in the given namespaces: ``runit``
    # is found via globals() and the local ``status`` via locals().
    cProfile.runctx("runit(status)", globals(), locals(), filename)
    stats = pstats.Stats(filename)

    # Keys of ``stats.stats`` are (file, line, funcname) tuples. Index the
    # first field of each value (the primitive call count) by function name.
    counts_by_methname = {
        key[2]: value[0] for key, value in stats.stats.items()
    }

    execute_key = "<method 'execute' of 'sqlite3.Cursor' objects>"
    executemany_key = "<method 'executemany' of 'sqlite3.Cursor' objects>"

    print("SQLA Version: %s" % __version__)
    print("Total calls %d" % stats.total_calls)
    print("Total cpu seconds: %.2f" % stats.total_tt)
    # Use .get() for both counters so a missing entry reports 0 instead of
    # raising KeyError after the whole profiled run has finished.
    print("Total execute calls: %d" % counts_by_methname.get(execute_key, 0))
    print(
        "Total executemany calls: %d"
        % counts_by_methname.get(executemany_key, 0)
    )

    if dump:
        stats.sort_stats("time", "calls")
        stats.print_stats()

    if runsnake:
        os.system("runsnake %s" % filename)
174
175
def run_with_time():
    """Run the large benchmark (factor 10) and report wall-clock progress.

    Each status line is prefixed with whole seconds elapsed since start. A
    final line reports the total.
    """
    import time

    started_at = time.time()

    def seconds_so_far():
        return time.time() - started_at

    def status(msg):
        print("%d - %s" % (seconds_so_far(), msg))

    runit(status, 10)
    print("Total time: %d" % seconds_so_far())
186
187
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="SQLAlchemy ORM benchmark: insert and reload "
        "Boss/Grunt objects in a SQLite file database."
    )
    parser.add_argument(
        "--profile",
        action="store_true",
        help="run shorter test suite w/ cProfile",
    )
    parser.add_argument(
        "--dump",
        action="store_true",
        help="dump full call profile (implies --profile)",
    )
    parser.add_argument(
        "--runsnake",
        action="store_true",
        help="invoke runsnakerun (implies --profile)",
    )

    args = parser.parse_args()

    # --dump and --runsnake only make sense on a profiled run, so either one
    # switches on --profile.
    args.profile = args.profile or args.dump or args.runsnake

    if args.profile:
        run_with_profile(runsnake=args.runsnake, dump=args.dump)
    else:
        run_with_time()
216