Skip to content

Fix inf values for small/zero dv and improve dense batch performance#62

Open
armanschwarz wants to merge 5 commits intolocuslab:masterfrom
armanschwarz:master
Open

Fix inf values for small/zero dv and improve dense batch performance#62
armanschwarz wants to merge 5 commits intolocuslab:masterfrom
armanschwarz:master

Conversation

@armanschwarz
Copy link

@armanschwarz armanschwarz commented Jan 27, 2026

Get_step will produce inf values, which causes nans when dv is small/zero. This fixes the issue.

I also improved performance by constructing objects directly on the target device as opposed to utilising to or type_as.

@armanschwarz
Copy link
Author

Regarding the fix to create objects directly on the device, here is performance before 813d7ea:

CPU times: user 6min 13s, sys: 18.4 s, total: 6min 31s
Wall time: 53.5 s
Timer unit: 1e-09 s

Total time: 37.6778 s
File: /home/jovyan/repo/qpth/qpth/solvers/pdipm/batch.py
Function: forward at line 47

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    47                                           def forward(Q, p, G, h, A, b, Q_LU, S_LU, R, eps=1e-12, verbose=0, notImprovedLim=3,
    48                                                       maxIter=20, solver=KKTSolvers.LU_PARTIAL):
    49                                               """
    50                                               Q_LU, S_LU, R = pre_factor_kkt(Q, G, A)
    51                                               """
    52       100    1630493.0  16304.9      0.0      nineq, nz, neq, nBatch = get_sizes(G, A)
    53                                           
    54                                               # Find initial values
    55       100     395387.0   3953.9      0.0      if solver == KKTSolvers.LU_FULL:
    56                                                   D = torch.eye(nineq).repeat(nBatch, 1, 1).type_as(Q)
    57                                                   x, s, z, y = factor_solve_kkt(
    58                                                       Q, D, G, A, p,
    59                                                       torch.zeros(nBatch, nineq).type_as(Q),
    60                                                       -h, -b if b is not None else None)
    61       100     116534.0   1165.3      0.0      elif solver == KKTSolvers.LU_PARTIAL:
    62       100   49409941.0 494099.4      0.1          d = torch.ones(nBatch, nineq).type_as(Q)
    63       100 1541289666.0 1.54e+07      4.1          factor_kkt(S_LU, R, d)
    64       200  147986103.0 739930.5      0.4          x, s, z, y = solve_kkt(
    65       100      68647.0    686.5      0.0              Q_LU, d, G, A, S_LU,
    66       100   46634890.0 466348.9      0.1              p, torch.zeros(nBatch, nineq).type_as(Q),
    67       100    1737878.0  17378.8      0.0              -h, -b if neq > 0 else None)
    68                                               elif solver == KKTSolvers.IR_UNOPT:
    69                                                   D = torch.eye(nineq).repeat(nBatch, 1, 1).type_as(Q)
    70                                                   x, s, z, y = solve_kkt_ir(
    71                                                       Q, D, G, A, p,
    72                                                       torch.zeros(nBatch, nineq).type_as(Q),
    73                                                       -h, -b if b is not None else None)
    74                                               else:
    75                                                   assert False
    76                                           
    77                                               # Make all of the slack variables >= 1.
    78       100    2825373.0  28253.7      0.0      M = torch.min(s, 1)[0]
    79       100    4081494.0  40814.9      0.0      M = M.view(M.size(0), 1).repeat(1, nineq)
    80       100    1958995.0  19590.0      0.0      I = M < 0
    81       100  216788446.0 2.17e+06      0.6      s[I] -= M[I] - 1
    82                                           
    83                                               # Make all of the inequality dual variables >= 1.
    84       100    1599022.0  15990.2      0.0      M = torch.min(z, 1)[0]
    85       100    3611581.0  36115.8      0.0      M = M.view(M.size(0), 1).repeat(1, nineq)
    86       100    1167114.0  11671.1      0.0      I = M < 0
    87       100   13202572.0 132025.7      0.0      z[I] -= M[I] - 1
    88                                           
    89       100     183071.0   1830.7      0.0      best = {'resids': None, 'x': None, 'z': None, 's': None, 'y': None}
    90       100      58243.0    582.4      0.0      nNotImproved = 0
    91                                           
    92      1610    1369360.0    850.5      0.0      for i in range(maxIter):
    93                                                   # affine scaling direction
    94      6440   45578110.0   7077.3      0.1          rx = (torch.bmm(y.unsqueeze(1), A).squeeze(1) if neq > 0 else 0.) + \
    95      1610   44637261.0  27725.0      0.1              torch.bmm(z.unsqueeze(1), G).squeeze(1) + \
    96      1610   26389428.0  16390.9      0.1              torch.bmm(x.unsqueeze(1), Q.transpose(1, 2)).squeeze(1) + \
    97      1610     672152.0    417.5      0.0              p
    98      1610    1806139.0   1121.8      0.0          rs = z
    99      1610   40646512.0  25246.3      0.1          rz = torch.bmm(x.unsqueeze(1), G.transpose(1, 2)).squeeze(1) + s - h
   100      1610    1189556.0    738.9      0.0          ry = torch.bmm(x.unsqueeze(1), A.transpose(
   101      1610    1217040.0    755.9      0.0              1, 2)).squeeze(1) - b if neq > 0 else 0.0
   102      1610   75489355.0  46887.8      0.2          mu = torch.abs((s * z).sum(1).squeeze() / nineq)
   103      1610   85839911.0  53316.7      0.2          z_resid = torch.norm(rz, 2, 1).squeeze()
   104      1610    1127368.0    700.2      0.0          y_resid = torch.norm(ry, 2, 1).squeeze() if neq > 0 else 0
   105      1610   14568933.0   9049.0      0.0          pri_resid = y_resid + z_resid
   106      1610   29952432.0  18604.0      0.1          dual_resid = torch.norm(rx, 2, 1).squeeze()
   107      1610   30608727.0  19011.6      0.1          resids = pri_resid + dual_resid + nineq * mu
   108                                           
   109      1610   11098458.0   6893.5      0.0          d = z / s
   110      1610     697553.0    433.3      0.0          try:
   111      1610     2.43e+10 1.51e+07     64.4              factor_kkt(S_LU, R, d)
   112                                                   except:
   113                                                       return best['x'], best['y'], best['z'], best['s']
   114                                           
   115      1610    1852880.0   1150.9      0.0          if verbose == 1:
   116      3220  673138686.0 209049.3      1.8              print('iter: {}, pri_resid: {:.5e}, dual_resid: {:.5e}, mu: {:.5e}'.format(
   117      1610   50895587.0  31612.2      0.1                  i, pri_resid.mean(), dual_resid.mean(), mu.mean()))
   118      1610    2292259.0   1423.8      0.0          if best['resids'] is None:
   119       100      77403.0    774.0      0.0              best['resids'] = resids
   120       100    1425972.0  14259.7      0.0              best['x'] = x.clone()
   121       100     977238.0   9772.4      0.0              best['z'] = z.clone()
   122       100     911104.0   9111.0      0.0              best['s'] = s.clone()
   123       100     103180.0   1031.8      0.0              best['y'] = y.clone() if y is not None else None
   124       100     122190.0   1221.9      0.0              nNotImproved = 0
   125                                                   else:
   126      1510   25423980.0  16837.1      0.1              I = resids < best['resids']
   127      1510   79915031.0  52923.9      0.2              if I.sum() > 0:
   128      1510    1439209.0    953.1      0.0                  nNotImproved = 0
   129                                                       else:
   130                                                           nNotImproved += 1
   131      1510   60587143.0  40123.9      0.2              I_nz = I.repeat(nz, 1).t()
   132      1510   32135085.0  21281.5      0.1              I_nineq = I.repeat(nineq, 1).t()
   133      1510  131731116.0  87239.1      0.3              best['resids'][I] = resids[I]
   134      1510  143826594.0  95249.4      0.4              best['x'][I_nz] = x[I_nz]
   135      1510  123887233.0  82044.5      0.3              best['z'][I_nineq] = z[I_nineq]
   136      1510  120426556.0  79752.7      0.3              best['s'][I_nineq] = s[I_nineq]
   137      1510    1538651.0   1019.0      0.0              if neq > 0:
   138                                                           I_neq = I.repeat(neq, 1).t()
   139                                                           best['y'][I_neq] = y[I_neq]
   140      1610  116121380.0  72125.1      0.3          if nNotImproved == notImprovedLim or best['resids'].max() < eps or mu.min() > 1e32:
   141       100    2421264.0  24212.6      0.0              if best['resids'].max() > 1. and verbose >= 0:
   142                                                           print(INACC_ERR)
   143       100     479819.0   4798.2      0.0              return best['x'], best['y'], best['z'], best['s']
   144                                           
   145      1510    2707634.0   1793.1      0.0          if solver == KKTSolvers.LU_FULL:
   146                                                       D = bdiag(d)
   147                                                       dx_aff, ds_aff, dz_aff, dy_aff = factor_solve_kkt(
   148                                                           Q, D, G, A, rx, rs, rz, ry)
   149      1510    1055612.0    699.1      0.0          elif solver == KKTSolvers.LU_PARTIAL:
   150      3020  639044751.0 211604.2      1.7              dx_aff, ds_aff, dz_aff, dy_aff = solve_kkt(
   151      1510    1035217.0    685.6      0.0                  Q_LU, d, G, A, S_LU, rx, rs, rz, ry)
   152                                                   elif solver == KKTSolvers.IR_UNOPT:
   153                                                       D = bdiag(d)
   154                                                       dx_aff, ds_aff, dz_aff, dy_aff = solve_kkt_ir(
   155                                                           Q, D, G, A, rx, rs, rz, ry)
   156                                                   else:
   157                                                       assert False
   158                                           
   159                                                   # compute centering directions
   160      4530 3361254088.0 741998.7      8.9          alpha = torch.min(torch.min(get_step(z, dz_aff),
   161      1510  133405411.0  88348.0      0.4                                      get_step(s, ds_aff)),
   162      1510   60668180.0  40177.6      0.2                            torch.ones(nBatch).type_as(Q))
   163      1510   56900574.0  37682.5      0.2          alpha_nineq = alpha.repeat(nineq, 1).t()
   164      1510   32354284.0  21426.7      0.1          t1 = s + alpha_nineq * ds_aff
   165      1510   17597627.0  11654.1      0.0          t2 = z + alpha_nineq * dz_aff
   166      1510   35543881.0  23539.0      0.1          t3 = torch.sum(t1 * t2, 1).squeeze()
   167      1510   21223797.0  14055.5      0.1          t4 = torch.sum(s * z, 1).squeeze()
   168      1510   53595660.0  35493.8      0.1          sig = (t3 / t4)**3
   169                                           
   170      1510  330720265.0 219020.0      0.9          rx = torch.zeros(nBatch, nz).type_as(Q)
   171      1510  105729318.0  70019.4      0.3          rs = ((-mu * sig).repeat(nineq, 1).t() + ds_aff * dz_aff) / s
   172      1510  404487539.0 267872.5      1.1          rz = torch.zeros(nBatch, nineq).type_as(Q)
   173      1510    9815742.0   6500.5      0.0          ry = torch.zeros(nBatch, neq).type_as(Q) if neq > 0 else torch.Tensor()
   174                                           
   175      1510    1870310.0   1238.6      0.0          if solver == KKTSolvers.LU_FULL:
   176                                                       D = bdiag(d)
   177                                                       dx_cor, ds_cor, dz_cor, dy_cor = factor_solve_kkt(
   178                                                           Q, D, G, A, rx, rs, rz, ry)
   179      1510     875921.0    580.1      0.0          elif solver == KKTSolvers.LU_PARTIAL:
   180      3020  408014566.0 135104.2      1.1              dx_cor, ds_cor, dz_cor, dy_cor = solve_kkt(
   181      1510     710202.0    470.3      0.0                  Q_LU, d, G, A, S_LU, rx, rs, rz, ry)
   182                                                   elif solver == KKTSolvers.IR_UNOPT:
   183                                                       D = bdiag(d)
   184                                                       dx_cor, ds_cor, dz_cor, dy_cor = solve_kkt_ir(
   185                                                           Q, D, G, A, rx, rs, rz, ry)
   186                                                   else:
   187                                                       assert False
   188                                           
   189      1510   11216295.0   7428.0      0.0          dx = dx_aff + dx_cor
   190      1510    9547409.0   6322.8      0.0          ds = ds_aff + ds_cor
   191      1510    8413444.0   5571.8      0.0          dz = dz_aff + dz_cor
   192      1510     981777.0    650.2      0.0          dy = dy_aff + dy_cor if neq > 0 else None
   193      4530 3352158506.0 739990.8      8.9          alpha = torch.min(0.999 * torch.min(get_step(z, dz),
   194      1510  131241374.0  86914.8      0.3                                              get_step(s, ds)),
   195      1510   51353775.0  34009.1      0.1                            torch.ones(nBatch).type_as(Q))
   196      1510   54058283.0  35800.2      0.1          alpha_nineq = alpha.repeat(nineq, 1).t()
   197      1510     987981.0    654.3      0.0          alpha_neq = alpha.repeat(neq, 1).t() if neq > 0 else None
   198      1510   30697741.0  20329.6      0.1          alpha_nz = alpha.repeat(nz, 1).t()
   199                                           
   200      1510   26311755.0  17425.0      0.1          x += alpha_nz * dx
   201      1510   16051131.0  10629.9      0.0          s += alpha_nineq * ds
   202      1510   14866201.0   9845.2      0.0          z += alpha_nineq * dz
   203      1510    1870041.0   1238.4      0.0          y = y + alpha_neq * dy if neq > 0 else None
   204                                           
   205                                               if best['resids'].max() > 1. and verbose >= 0:
   206                                                   print(INACC_ERR)
   207                                               return best['x'], best['y'], best['z'], best['s']

And after the fix:

CPU times: user 26.7 s, sys: 4.54 s, total: 31.3 s
Wall time: 31.8 s
Timer unit: 1e-09 s

Total time: 15.1151 s
File: /home/jovyan/repo/qpth/qpth/solvers/pdipm/batch.py
Function: forward at line 58

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    58                                           @profile
    59                                           def forward(
    60                                               Q,
    61                                               p,
    62                                               G,
    63                                               h,
    64                                               A,
    65                                               b,
    66                                               Q_LU,
    67                                               S_LU,
    68                                               R,
    69                                               eps=1e-12,
    70                                               verbose=0,
    71                                               notImprovedLim=3,
    72                                               maxIter=20,
    73                                               solver=KKTSolvers.LU_PARTIAL,
    74                                           ):
    75                                               """
    76                                               Q_LU, S_LU, R = pre_factor_kkt(Q, G, A)
    77                                               """
    78       100     595559.0   5955.6      0.0      nineq, nz, neq, nBatch = get_sizes(G, A)
    79                                           
    80                                               # Find initial values
    81       100     123522.0   1235.2      0.0      if solver == KKTSolvers.LU_FULL:
    82                                                   D = torch.eye(nineq, device=Q.device, dtype=Q.dtype).repeat(nBatch, 1, 1)
    83                                                   x, s, z, y = factor_solve_kkt(
    84                                                       Q,
    85                                                       D,
    86                                                       G,
    87                                                       A,
    88                                                       p,
    89                                                       torch.zeros(nBatch, nineq, device=Q.device, dtype=Q.dtype),
    90                                                       -h,
    91                                                       -b if b is not None else None,
    92                                                   )
    93       100      57112.0    571.1      0.0      elif solver == KKTSolvers.LU_PARTIAL:
    94       100     737943.0   7379.4      0.0          d = torch.ones(nBatch, nineq, device=Q.device, dtype=Q.dtype)
    95       100  649259835.0 6.49e+06      4.3          factor_kkt(S_LU, R, d)
    96       200   27969181.0 139845.9      0.2          x, s, z, y = solve_kkt(
    97       100      45269.0    452.7      0.0              Q_LU,
    98       100      39681.0    396.8      0.0              d,
    99       100      37619.0    376.2      0.0              G,
   100       100      40178.0    401.8      0.0              A,
   101       100      34563.0    345.6      0.0              S_LU,
   102       100      36075.0    360.8      0.0              p,
   103       100    1147858.0  11478.6      0.0              torch.zeros(nBatch, nineq, device=Q.device, dtype=Q.dtype),
   104       100     772722.0   7727.2      0.0              -h,
   105       100      51492.0    514.9      0.0              -b if neq > 0 else None,
   106                                                   )
   107                                               elif solver == KKTSolvers.IR_UNOPT:
   108                                                   D = torch.eye(nineq, device=Q.device, dtype=Q.dtype).repeat(nBatch, 1, 1)
   109                                                   x, s, z, y = solve_kkt_ir(
   110                                                       Q,
   111                                                       D,
   112                                                       G,
   113                                                       A,
   114                                                       p,
   115                                                       torch.zeros(nBatch, nineq, device=Q.device, dtype=Q.dtype),
   116                                                       -h,
   117                                                       -b if b is not None else None,
   118                                                   )
   119                                               else:
   120                                                   assert False
   121                                           
   122                                               # Make all of the slack variables >= 1.
   123       100    1873977.0  18739.8      0.0      M = torch.min(s, 1)[0]
   124       100    3523400.0  35234.0      0.0      M = M.view(M.size(0), 1).repeat(1, nineq)
   125       100    1219031.0  12190.3      0.0      I = M < 0
   126       100  259643186.0  2.6e+06      1.7      s[I] -= M[I] - 1
   127                                           
   128                                               # Make all of the inequality dual variables >= 1.
   129       100    1519293.0  15192.9      0.0      M = torch.min(z, 1)[0]
   130       100    3642680.0  36426.8      0.0      M = M.view(M.size(0), 1).repeat(1, nineq)
   131       100    1059719.0  10597.2      0.0      I = M < 0
   132       100   12284957.0 122849.6      0.1      z[I] -= M[I] - 1
   133                                           
   134       100     134735.0   1347.3      0.0      best = {"resids": None, "x": None, "z": None, "s": None, "y": None}
   135       100      39783.0    397.8      0.0      nNotImproved = 0
   136                                           
   137      1610    1052881.0    654.0      0.0      for i in range(maxIter):
   138                                                   # affine scaling direction
   139      1610    1461464.0    907.7      0.0          rx = (
   140      6440  193593045.0  30061.0      1.3              (torch.bmm(y.unsqueeze(1), A).squeeze(1) if neq > 0 else 0.0)
   141      1610   42256377.0  26246.2      0.3              + torch.bmm(z.unsqueeze(1), G).squeeze(1)
   142      1610   23945156.0  14872.8      0.2              + torch.bmm(x.unsqueeze(1), Q.transpose(1, 2)).squeeze(1)
   143      1610     690027.0    428.6      0.0              + p
   144                                                   )
   145      1610    1620338.0   1006.4      0.0          rs = z
   146      1610   38517396.0  23923.8      0.3          rz = torch.bmm(x.unsqueeze(1), G.transpose(1, 2)).squeeze(1) + s - h
   147      1610    1020870.0    634.1      0.0          ry = (
   148                                                       torch.bmm(x.unsqueeze(1), A.transpose(1, 2)).squeeze(1) - b
   149      1610     841027.0    522.4      0.0              if neq > 0
   150      1610     672098.0    417.5      0.0              else 0.0
   151                                                   )
   152      1610   60639583.0  37664.3      0.4          mu = torch.abs((s * z).sum(1).squeeze() / nineq)
   153      1610   57764659.0  35878.7      0.4          z_resid = torch.norm(rz, 2, 1).squeeze()
   154      1610     861435.0    535.1      0.0          y_resid = torch.norm(ry, 2, 1).squeeze() if neq > 0 else 0
   155      1610   12962418.0   8051.2      0.1          pri_resid = y_resid + z_resid
   156      1610   28807170.0  17892.7      0.2          dual_resid = torch.norm(rx, 2, 1).squeeze()
   157      1610   28344721.0  17605.4      0.2          resids = pri_resid + dual_resid + nineq * mu
   158                                           
   159      1610   10068586.0   6253.8      0.1          d = z / s
   160      1610     697226.0    433.1      0.0          try:
   161      1610 3804675523.0 2.36e+06     25.2              factor_kkt(S_LU, R, d)
   162                                                   except:
   163                                                       return best["x"], best["y"], best["z"], best["s"]
   164                                           
   165      1610    1265254.0    785.9      0.0          if verbose == 1:
   166      3220    6030172.0   1872.7      0.0              print(
   167      3220  674609952.0 209506.2      4.5                  "iter: {}, pri_resid: {:.5e}, dual_resid: {:.5e}, mu: {:.5e}".format(
   168      1610   37813292.0  23486.5      0.3                      i, pri_resid.mean(), dual_resid.mean(), mu.mean()
   169                                                           )
   170                                                       )
   171      1610    1062415.0    659.9      0.0          if best["resids"] is None:
   172       100      53128.0    531.3      0.0              best["resids"] = resids
   173       100    1176072.0  11760.7      0.0              best["x"] = x.clone()
   174       100     834192.0   8341.9      0.0              best["z"] = z.clone()
   175       100     700872.0   7008.7      0.0              best["s"] = s.clone()
   176       100      59897.0    599.0      0.0              best["y"] = y.clone() if y is not None else None
   177       100      56003.0    560.0      0.0              nNotImproved = 0
   178                                                   else:
   179      1510   16551501.0  10961.3      0.1              I = resids < best["resids"]
   180      1510   57670480.0  38192.4      0.4              if I.sum() > 0:
   181      1510     761814.0    504.5      0.0                  nNotImproved = 0
   182                                                       else:
   183                                                           nNotImproved += 1
   184      1510   50105634.0  33182.5      0.3              I_nz = I.repeat(nz, 1).t()
   185      1510   28591745.0  18934.9      0.2              I_nineq = I.repeat(nineq, 1).t()
   186      1510  122544536.0  81155.3      0.8              best["resids"][I] = resids[I]
   187      1510  124263209.0  82293.5      0.8              best["x"][I_nz] = x[I_nz]
   188      1510  122240148.0  80953.7      0.8              best["z"][I_nineq] = z[I_nineq]
   189      1510  110653241.0  73280.3      0.7              best["s"][I_nineq] = s[I_nineq]
   190      1510    1013284.0    671.0      0.0              if neq > 0:
   191                                                           I_neq = I.repeat(neq, 1).t()
   192                                                           best["y"][I_neq] = y[I_neq]
   193                                                   if (
   194      1610     687720.0    427.2      0.0              nNotImproved == notImprovedLim
   195      1610   52354694.0  32518.4      0.3              or best["resids"].max() < eps
   196      1510   37295929.0  24699.3      0.2              or mu.min() > 1e32
   197                                                   ):
   198       100    2217776.0  22177.8      0.0              if best["resids"].max() > 1.0 and verbose >= 0:
   199                                                           print(INACC_ERR)
   200       100     490619.0   4906.2      0.0              return best["x"], best["y"], best["z"], best["s"]
   201                                           
   202      1510    1442628.0    955.4      0.0          if solver == KKTSolvers.LU_FULL:
   203                                                       D = bdiag(d)
   204                                                       dx_aff, ds_aff, dz_aff, dy_aff = factor_solve_kkt(
   205                                                           Q, D, G, A, rx, rs, rz, ry
   206                                                       )
   207      1510     779161.0    516.0      0.0          elif solver == KKTSolvers.LU_PARTIAL:
   208      3020  438808610.0 145300.9      2.9              dx_aff, ds_aff, dz_aff, dy_aff = solve_kkt(
   209      1510     618877.0    409.9      0.0                  Q_LU, d, G, A, S_LU, rx, rs, rz, ry
   210                                                       )
   211                                                   elif solver == KKTSolvers.IR_UNOPT:
   212                                                       D = bdiag(d)
   213                                                       dx_aff, ds_aff, dz_aff, dy_aff = solve_kkt_ir(Q, D, G, A, rx, rs, rz, ry)
   214                                                   else:
   215                                                       assert False
   216                                           
   217                                                   # compute centering directions
   218      3020   11450865.0   3791.7      0.1          alpha = torch.min(
   219      1510 3547126708.0 2.35e+06     23.5              torch.min(get_step(z, dz_aff), get_step(s, ds_aff)),
   220      1510   11029930.0   7304.6      0.1              torch.ones(nBatch, device=Q.device, dtype=Q.dtype),
   221                                                   )
   222      1510   57194385.0  37877.1      0.4          alpha_nineq = alpha.repeat(nineq, 1).t()
   223      1510   27082989.0  17935.8      0.2          t1 = s + alpha_nineq * ds_aff
   224      1510   16450757.0  10894.5      0.1          t2 = z + alpha_nineq * dz_aff
   225      1510   27472447.0  18193.7      0.2          t3 = torch.sum(t1 * t2, 1).squeeze()
   226      1510   18967494.0  12561.3      0.1          t4 = torch.sum(s * z, 1).squeeze()
   227      1510   49642361.0  32875.7      0.3          sig = (t3 / t4) ** 3
   228                                           
   229      1510   15946838.0  10560.8      0.1          rx = torch.zeros(nBatch, nz, device=Q.device, dtype=Q.dtype)
   230      1510   85805120.0  56824.6      0.6          rs = ((-mu * sig).repeat(nineq, 1).t() + ds_aff * dz_aff) / s
   231      1510   12354851.0   8182.0      0.1          rz = torch.zeros(nBatch, nineq, device=Q.device, dtype=Q.dtype)
   232      1510     659108.0    436.5      0.0          ry = (
   233                                                       torch.zeros(nBatch, neq, device=Q.device, dtype=Q.dtype)
   234      1510     821734.0    544.2      0.0              if neq > 0
   235      1510    4653594.0   3081.9      0.0              else torch.Tensor()
   236                                                   )
   237                                           
   238      1510    1317668.0    872.6      0.0          if solver == KKTSolvers.LU_FULL:
   239                                                       D = bdiag(d)
   240                                                       dx_cor, ds_cor, dz_cor, dy_cor = factor_solve_kkt(
   241                                                           Q, D, G, A, rx, rs, rz, ry
   242                                                       )
   243      1510     735030.0    486.8      0.0          elif solver == KKTSolvers.LU_PARTIAL:
   244      3020  366129351.0 121234.9      2.4              dx_cor, ds_cor, dz_cor, dy_cor = solve_kkt(
   245      1510     583433.0    386.4      0.0                  Q_LU, d, G, A, S_LU, rx, rs, rz, ry
   246                                                       )
   247                                                   elif solver == KKTSolvers.IR_UNOPT:
   248                                                       D = bdiag(d)
   249                                                       dx_cor, ds_cor, dz_cor, dy_cor = solve_kkt_ir(Q, D, G, A, rx, rs, rz, ry)
   250                                                   else:
   251                                                       assert False
   252                                           
   253      1510   10147074.0   6719.9      0.1          dx = dx_aff + dx_cor
   254      1510    8546676.0   5660.1      0.1          ds = ds_aff + ds_cor
   255      1510    7521079.0   4980.8      0.0          dz = dz_aff + dz_cor
   256      1510     753348.0    498.9      0.0          dy = dy_aff + dy_cor if neq > 0 else None
   257      3020   11138040.0   3688.1      0.1          alpha = torch.min(
   258      1510 3501114935.0 2.32e+06     23.2              0.999 * torch.min(get_step(z, dz), get_step(s, ds)),
   259      1510   10956440.0   7255.9      0.1              torch.ones(nBatch, device=Q.device, dtype=Q.dtype),
   260                                                   )
   261      1510   55463998.0  36731.1      0.4          alpha_nineq = alpha.repeat(nineq, 1).t()
   262      1510     976753.0    646.9      0.0          alpha_neq = alpha.repeat(neq, 1).t() if neq > 0 else None
   263      1510   29093953.0  19267.5      0.2          alpha_nz = alpha.repeat(nz, 1).t()
   264                                           
   265      1510   24312012.0  16100.7      0.2          x += alpha_nz * dx
   266      1510   15258568.0  10105.0      0.1          s += alpha_nineq * ds
   267      1510   14041913.0   9299.3      0.1          z += alpha_nineq * dz
   268      1510    1172061.0    776.2      0.0          y = y + alpha_neq * dy if neq > 0 else None
   269                                           
   270                                               if best["resids"].max() > 1.0 and verbose >= 0:
   271                                                   print(INACC_ERR)
   272                                               return best["x"], best["y"], best["z"], best["s"]

@armanschwarz armanschwarz changed the title Fix inf values for small/zero dv Fix inf values for small/zero dv and improve performance Feb 2, 2026
@armanschwarz armanschwarz changed the title Fix inf values for small/zero dv and improve performance Fix inf values for small/zero dv and improve dense batch performance Feb 2, 2026
@armanschwarz
Copy link
Author

624e92b fixes a performance regression introduced by my first fix. Performance is now further improved:

CPU times: user 25.1 s, sys: 3.21 s, total: 28.3 s
Wall time: 28.4 s
Timer unit: 1e-09 s

Total time: 14.1164 s
File: /home/jovyan/repo/qpth/qpth/solvers/pdipm/batch.py
Function: forward at line 58

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    58                                           @profile
    59                                           def forward(
    60                                               Q,
    61                                               p,
    62                                               G,
    63                                               h,
    64                                               A,
    65                                               b,
    66                                               Q_LU,
    67                                               S_LU,
    68                                               R,
    69                                               eps=1e-12,
    70                                               verbose=0,
    71                                               notImprovedLim=3,
    72                                               maxIter=20,
    73                                               solver=KKTSolvers.LU_PARTIAL,
    74                                           ):
    75                                               """
    76                                               Q_LU, S_LU, R = pre_factor_kkt(Q, G, A)
    77                                               """
    78       100     610572.0   6105.7      0.0      nineq, nz, neq, nBatch = get_sizes(G, A)
    79                                           
    80                                               # Find initial values
    81       100     186959.0   1869.6      0.0      if solver == KKTSolvers.LU_FULL:
    82                                                   D = torch.eye(nineq, device=Q.device, dtype=Q.dtype).repeat(nBatch, 1, 1)
    83                                                   x, s, z, y = factor_solve_kkt(
    84                                                       Q,
    85                                                       D,
    86                                                       G,
    87                                                       A,
    88                                                       p,
    89                                                       torch.zeros(nBatch, nineq, device=Q.device, dtype=Q.dtype),
    90                                                       -h,
    91                                                       -b if b is not None else None,
    92                                                   )
    93       100      58959.0    589.6      0.0      elif solver == KKTSolvers.LU_PARTIAL:
    94       100     784550.0   7845.5      0.0          d = torch.ones(nBatch, nineq, device=Q.device, dtype=Q.dtype)
    95       100  649381676.0 6.49e+06      4.6          factor_kkt(S_LU, R, d)
    96       200   25682714.0 128413.6      0.2          x, s, z, y = solve_kkt(
    97       100      48730.0    487.3      0.0              Q_LU,
    98       100      39417.0    394.2      0.0              d,
    99       100      50275.0    502.8      0.0              G,
   100       100      41208.0    412.1      0.0              A,
   101       100      36934.0    369.3      0.0              S_LU,
   102       100      39577.0    395.8      0.0              p,
   103       100    1135169.0  11351.7      0.0              torch.zeros(nBatch, nineq, device=Q.device, dtype=Q.dtype),
   104       100     756849.0   7568.5      0.0              -h,
   105       100      49690.0    496.9      0.0              -b if neq > 0 else None,
   106                                                   )
   107                                               elif solver == KKTSolvers.IR_UNOPT:
   108                                                   D = torch.eye(nineq, device=Q.device, dtype=Q.dtype).repeat(nBatch, 1, 1)
   109                                                   x, s, z, y = solve_kkt_ir(
   110                                                       Q,
   111                                                       D,
   112                                                       G,
   113                                                       A,
   114                                                       p,
   115                                                       torch.zeros(nBatch, nineq, device=Q.device, dtype=Q.dtype),
   116                                                       -h,
   117                                                       -b if b is not None else None,
   118                                                   )
   119                                               else:
   120                                                   assert False
   121                                           
   122                                               # Make all of the slack variables >= 1.
   123       100    2210877.0  22108.8      0.0      M = torch.min(s, 1)[0]
   124       100    3492308.0  34923.1      0.0      M = M.view(M.size(0), 1).repeat(1, nineq)
   125       100    1195791.0  11957.9      0.0      I = M < 0
   126       100  264020540.0 2.64e+06      1.9      s[I] -= M[I] - 1
   127                                           
   128                                               # Make all of the inequality dual variables >= 1.
   129       100    1576061.0  15760.6      0.0      M = torch.min(z, 1)[0]
   130       100    3562532.0  35625.3      0.0      M = M.view(M.size(0), 1).repeat(1, nineq)
   131       100    1082585.0  10825.9      0.0      I = M < 0
   132       100   13032905.0 130329.1      0.1      z[I] -= M[I] - 1
   133                                           
   134       100     141724.0   1417.2      0.0      best = {"resids": None, "x": None, "z": None, "s": None, "y": None}
   135       100      48103.0    481.0      0.0      nNotImproved = 0
   136                                           
   137      1610    1091619.0    678.0      0.0      for i in range(maxIter):
   138                                                   # affine scaling direction
   139      1610    1604908.0    996.8      0.0          rx = (
   140      6440   39723236.0   6168.2      0.3              (torch.bmm(y.unsqueeze(1), A).squeeze(1) if neq > 0 else 0.0)
   141      1610   36575004.0  22717.4      0.3              + torch.bmm(z.unsqueeze(1), G).squeeze(1)
   142      1610   23443406.0  14561.1      0.2              + torch.bmm(x.unsqueeze(1), Q.transpose(1, 2)).squeeze(1)
   143      1610     695153.0    431.8      0.0              + p
   144                                                   )
   145      1610    1580369.0    981.6      0.0          rs = z
   146      1610   38552626.0  23945.7      0.3          rz = torch.bmm(x.unsqueeze(1), G.transpose(1, 2)).squeeze(1) + s - h
   147      1610    1150347.0    714.5      0.0          ry = (
   148                                                       torch.bmm(x.unsqueeze(1), A.transpose(1, 2)).squeeze(1) - b
   149      1610     891169.0    553.5      0.0              if neq > 0
   150      1610     802715.0    498.6      0.0              else 0.0
   151                                                   )
   152      1610   59586540.0  37010.3      0.4          mu = torch.abs((s * z).sum(1).squeeze() / nineq)
   153      1610   49989673.0  31049.5      0.4          z_resid = torch.norm(rz, 2, 1).squeeze()
   154      1610     877710.0    545.2      0.0          y_resid = torch.norm(ry, 2, 1).squeeze() if neq > 0 else 0
   155      1610   13076259.0   8121.9      0.1          pri_resid = y_resid + z_resid
   156      1610   29756328.0  18482.2      0.2          dual_resid = torch.norm(rx, 2, 1).squeeze()
   157      1610   29340074.0  18223.6      0.2          resids = pri_resid + dual_resid + nineq * mu
   158                                           
   159      1610   10253320.0   6368.5      0.1          d = z / s
   160      1610     711106.0    441.7      0.0          try:
   161      1610 9770235437.0 6.07e+06     69.2              factor_kkt(S_LU, R, d)
   162                                                   except:
   163                                                       return best["x"], best["y"], best["z"], best["s"]
   164                                           
   165      1610    1398953.0    868.9      0.0          if verbose == 1:
   166      3220    5492765.0   1705.8      0.0              print(
   167      3220  674902860.0 209597.2      4.8                  "iter: {}, pri_resid: {:.5e}, dual_resid: {:.5e}, mu: {:.5e}".format(
   168      1610   34347814.0  21334.0      0.2                      i, pri_resid.mean(), dual_resid.mean(), mu.mean()
   169                                                           )
   170                                                       )
   171      1610    1168924.0    726.0      0.0          if best["resids"] is None:
   172       100      59046.0    590.5      0.0              best["resids"] = resids
   173       100    1138736.0  11387.4      0.0              best["x"] = x.clone()
   174       100     935273.0   9352.7      0.0              best["z"] = z.clone()
   175       100     729661.0   7296.6      0.0              best["s"] = s.clone()
   176       100      63666.0    636.7      0.0              best["y"] = y.clone() if y is not None else None
   177       100      60334.0    603.3      0.0              nNotImproved = 0
   178                                                   else:
   179      1510   16609327.0  10999.6      0.1              I = resids < best["resids"]
   180      1510   58272283.0  38590.9      0.4              if I.sum() > 0:
   181      1510     859380.0    569.1      0.0                  nNotImproved = 0
   182                                                       else:
   183                                                           nNotImproved += 1
   184      1510   49846238.0  33010.8      0.4              I_nz = I.repeat(nz, 1).t()
   185      1510   28435608.0  18831.5      0.2              I_nineq = I.repeat(nineq, 1).t()
   186      1510  124196803.0  82249.5      0.9              best["resids"][I] = resids[I]
   187      1510  134662703.0  89180.6      1.0              best["x"][I_nz] = x[I_nz]
   188      1510  121791763.0  80656.8      0.9              best["z"][I_nineq] = z[I_nineq]
   189      1510  118980208.0  78794.8      0.8              best["s"][I_nineq] = s[I_nineq]
   190      1510    1140056.0    755.0      0.0              if neq > 0:
   191                                                           I_neq = I.repeat(neq, 1).t()
   192                                                           best["y"][I_neq] = y[I_neq]
   193                                                   if (
   194      1610     923761.0    573.8      0.0              nNotImproved == notImprovedLim
   195      1610   53913735.0  33486.8      0.4              or best["resids"].max() < eps
   196      1510   38610267.0  25569.7      0.3              or mu.min() > 1e32
   197                                                   ):
   198       100    2336342.0  23363.4      0.0              if best["resids"].max() > 1.0 and verbose >= 0:
   199                                                           print(INACC_ERR)
   200       100     486941.0   4869.4      0.0              return best["x"], best["y"], best["z"], best["s"]
   201                                           
   202      1510    1468060.0    972.2      0.0          if solver == KKTSolvers.LU_FULL:
   203                                                       D = bdiag(d)
   204                                                       dx_aff, ds_aff, dz_aff, dy_aff = factor_solve_kkt(
   205                                                           Q, D, G, A, rx, rs, rz, ry
   206                                                       )
   207      1510     841150.0    557.1      0.0          elif solver == KKTSolvers.LU_PARTIAL:
   208      3020  391554533.0 129653.8      2.8              dx_aff, ds_aff, dz_aff, dy_aff = solve_kkt(
   209      1510     671452.0    444.7      0.0                  Q_LU, d, G, A, S_LU, rx, rs, rz, ry
   210                                                       )
   211                                                   elif solver == KKTSolvers.IR_UNOPT:
   212                                                       D = bdiag(d)
   213                                                       dx_aff, ds_aff, dz_aff, dy_aff = solve_kkt_ir(Q, D, G, A, rx, rs, rz, ry)
   214                                                   else:
   215                                                       assert False
   216                                           
   217                                                   # compute centering directions
   218      3020   11591856.0   3838.4      0.1          alpha = torch.min(
   219      1510  151601836.0 100398.6      1.1              torch.min(get_step(z, dz_aff), get_step(s, ds_aff)),
   220      1510   13620547.0   9020.2      0.1              torch.ones(nBatch, device=Q.device, dtype=Q.dtype),
   221                                                   )
   222      1510   51589823.0  34165.4      0.4          alpha_nineq = alpha.repeat(nineq, 1).t()
   223      1510   25767807.0  17064.8      0.2          t1 = s + alpha_nineq * ds_aff
   224      1510   16605482.0  10997.0      0.1          t2 = z + alpha_nineq * dz_aff
   225      1510   27576867.0  18262.8      0.2          t3 = torch.sum(t1 * t2, 1).squeeze()
   226      1510   19652910.0  13015.2      0.1          t4 = torch.sum(s * z, 1).squeeze()
   227      1510   35365608.0  23420.9      0.3          sig = (t3 / t4) ** 3
   228                                           
   229      1510   15995163.0  10592.8      0.1          rx = torch.zeros(nBatch, nz, device=Q.device, dtype=Q.dtype)
   230      1510   84580072.0  56013.3      0.6          rs = ((-mu * sig).repeat(nineq, 1).t() + ds_aff * dz_aff) / s
   231      1510   12132344.0   8034.7      0.1          rz = torch.zeros(nBatch, nineq, device=Q.device, dtype=Q.dtype)
   232      1510     748193.0    495.5      0.0          ry = (
   233                                                       torch.zeros(nBatch, neq, device=Q.device, dtype=Q.dtype)
   234      1510     757772.0    501.8      0.0              if neq > 0
   235      1510    4193746.0   2777.3      0.0              else torch.Tensor()
   236                                                   )
   237                                           
   238      1510    1281750.0    848.8      0.0          if solver == KKTSolvers.LU_FULL:
   239                                                       D = bdiag(d)
   240                                                       dx_cor, ds_cor, dz_cor, dy_cor = factor_solve_kkt(
   241                                                           Q, D, G, A, rx, rs, rz, ry
   242                                                       )
   243      1510     712359.0    471.8      0.0          elif solver == KKTSolvers.LU_PARTIAL:
   244      3020  358488746.0 118704.9      2.5              dx_cor, ds_cor, dz_cor, dy_cor = solve_kkt(
   245      1510     613210.0    406.1      0.0                  Q_LU, d, G, A, S_LU, rx, rs, rz, ry
   246                                                       )
   247                                                   elif solver == KKTSolvers.IR_UNOPT:
   248                                                       D = bdiag(d)
   249                                                       dx_cor, ds_cor, dz_cor, dy_cor = solve_kkt_ir(Q, D, G, A, rx, rs, rz, ry)
   250                                                   else:
   251                                                       assert False
   252                                           
   253      1510   10329545.0   6840.8      0.1          dx = dx_aff + dx_cor
   254      1510    8645353.0   5725.4      0.1          ds = ds_aff + ds_cor
   255      1510    7799071.0   5164.9      0.1          dz = dz_aff + dz_cor
   256      1510     801879.0    531.0      0.0          dy = dy_aff + dy_cor if neq > 0 else None
   257      3020   10983041.0   3636.8      0.1          alpha = torch.min(
   258      1510  156087178.0 103369.0      1.1              0.999 * torch.min(get_step(z, dz), get_step(s, ds)),
   259      1510   12378565.0   8197.7      0.1              torch.ones(nBatch, device=Q.device, dtype=Q.dtype),
   260                                                   )
   261      1510   49432395.0  32736.7      0.4          alpha_nineq = alpha.repeat(nineq, 1).t()
   262      1510     959859.0    635.7      0.0          alpha_neq = alpha.repeat(neq, 1).t() if neq > 0 else None
   263      1510   27959596.0  18516.3      0.2          alpha_nz = alpha.repeat(nz, 1).t()
   264                                           
   265      1510   22714889.0  15043.0      0.2          x += alpha_nz * dx
   266      1510   15118504.0  10012.3      0.1          s += alpha_nineq * ds
   267      1510   14236796.0   9428.3      0.1          z += alpha_nineq * dz
   268      1510    1014714.0    672.0      0.0          y = y + alpha_neq * dy if neq > 0 else None
   269                                           
   270                                               if best["resids"].max() > 1.0 and verbose >= 0:
   271                                                   print(INACC_ERR)
   272                                               return best["x"], best["y"], best["z"], best["s"]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant