1from cvxpy.atoms.affine.bmat import bmat
2from cvxpy.atoms.affine.hstack import hstack
3from cvxpy.atoms.affine.promote import promote
4from cvxpy.atoms.affine.reshape import reshape
5from cvxpy.atoms.log_sum_exp import log_sum_exp
6
7
def add_canon(expr, args):
    """Canonicalize an addition by combining its arguments with log_sum_exp.

    This is the log-space form of a sum, log(sum_k exp(a_k)). It is
    presumably used by a DGP-to-DCP reduction; confirm against the caller.

    Args:
        expr: The addition expression being canonicalized. Only its
            ``is_scalar()`` result and ``shape`` are read here.
        args: The (already canonicalized) summands of ``expr``.

    Returns:
        A tuple ``(canonical_expr, constraints)``. ``constraints`` is always
        an empty list. ``canonical_expr`` is a single ``log_sum_exp`` when
        ``expr`` is scalar. Otherwise it is an element-wise ``log_sum_exp``
        reshaped to ``expr.shape``.
    """
    if expr.is_scalar():
        return log_sum_exp(hstack(args)), []

    # Scalar summands are promoted to the full shape so that every summand
    # can be indexed element by element below.
    summands = [
        promote(s, expr.shape) if s.is_scalar() else s for s in args]

    def elementwise_lse(index):
        # log_sum_exp over the entry at ``index`` of every summand.
        return log_sum_exp(hstack([summand[index] for summand in summands]))

    if len(expr.shape) == 1:
        # One single-entry row per element gives an (n, 1) block matrix.
        # The reshape below turns it back into expr's 1-D shape.
        rows = [[elementwise_lse(i)] for i in range(expr.shape[0])]
    else:
        # NOTE(review): only shape[0] and shape[1] are iterated, so shapes
        # with more than two dimensions are treated as 2-D here.
        n_rows, n_cols = expr.shape[0], expr.shape[1]
        rows = [[elementwise_lse((i, j)) for j in range(n_cols)]
                for i in range(n_rows)]
    return reshape(bmat(rows), expr.shape), []
30