import numpy as np from numpy.testing import assert_allclose from pytest import raises as assert_raises from scipy.optimize import nnls class TestNNLS: def setup_method(self): self.rng = np.random.default_rng(1685225766635251) def test_nnls(self): a = np.arange(25.0).reshape(-1, 5) x = np.arange(5.0) y = a @ x x, res = nnls(a, y) assert res < 1e-7 assert np.linalg.norm((a @ x) - y) < 1e-7 def test_nnls_tall(self): a = self.rng.uniform(low=-10, high=10, size=[50, 10]) x = np.abs(self.rng.uniform(low=-2, high=2, size=[10])) x[::2] = 0 b = a @ x xact, rnorm = nnls(a, b, atol=500*np.linalg.norm(a, 1)*np.spacing(1.)) assert_allclose(xact, x, rtol=0., atol=1e-10) assert rnorm < 1e-12 def test_nnls_wide(self): # If too wide then problem becomes too ill-conditioned ans starts # emitting warnings, hence small m, n difference. a = self.rng.uniform(low=-10, high=10, size=[100, 120]) x = np.abs(self.rng.uniform(low=-2, high=2, size=[120])) x[::2] = 0 b = a @ x xact, rnorm = nnls(a, b, atol=500*np.linalg.norm(a, 1)*np.spacing(1.)) assert_allclose(xact, x, rtol=0., atol=1e-10) assert rnorm < 1e-12 def test_maxiter(self): # test that maxiter argument does stop iterations a = self.rng.uniform(size=(5, 10)) b = self.rng.uniform(size=5) with assert_raises(RuntimeError): nnls(a, b, maxiter=1) def test_nnls_inner_loop_case1(self): # See gh-20168 n = np.array( [3, 2, 0, 1, 1, 1, 3, 8, 14, 16, 29, 23, 41, 47, 53, 57, 67, 76, 103, 89, 97, 94, 85, 95, 78, 78, 78, 77, 73, 50, 50, 56, 68, 98, 95, 112, 134, 145, 158, 172, 213, 234, 222, 215, 216, 216, 206, 183, 135, 156, 110, 92, 63, 60, 52, 29, 20, 16, 12, 5, 5, 5, 1, 2, 3, 0, 2]) k = np.array( [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.7205812007860187, 0., 1.4411624015720375, 0.7205812007860187, 2.882324803144075, 5.76464960628815, 5.76464960628815, 12.249880413362318, 15.132205216506394, 20.176273622008523, 27.382085629868712, 48.27894045266326, 47.558359251877235, 68.45521407467177, 97.99904330689854, 108.0871801179028, 135.46926574777152, 140.51333415327366, 184.4687874012208, 171.49832578707245, 205.36564222401535, 244.27702706646033, 214.01261663344755, 228.42424064916793, 232.02714665309804, 205.36564222401535, 172.9394881886445, 191.67459940908097, 162.1307701768542, 153.48379576742198, 110.96950492104689, 103.04311171240067, 86.46974409432225, 60.528820866025576, 43.234872047161126, 23.779179625938617, 24.499760826724636, 17.29394881886445, 11.5292992125763, 5.76464960628815, 5.044068405502131, 3.6029060039300935, 0., 2.882324803144075, 0., 0., 0.]) d = np.array( [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.003889242101538, 0., 0.007606268390096, 0., 0.025457371599973, 0.036952882091577, 0., 0.08518359183449, 0.048201126400243, 0.196234990022205, 0.144116240157247, 0.171145134062442, 0., 0., 0.269555036538714, 0., 0., 0., 0.010893241091872, 0., 0., 0., 0., 0., 0., 0., 0., 0.048167058272886, 0.011238724891049, 0., 0., 0.055162603456078, 0., 0., 0., 0., 0.027753339088588, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]) # The following code sets up a system of equations such that # $k_i-p_i*n_i$ is minimized for $p_i$ with weights $n_i$ and # monotonicity constraints on $p_i$. This translates to a system of # equations of the form $k_i - (d_1 + ... + d_i) * n_i$ and # non-negativity constraints on the $d_i$. If $n_i$ is zero the # system is modified such that $d_i - d_{i+1}$ is then minimized. N = len(n) A = np.diag(n) @ np.tril(np.ones((N, N))) w = n ** 0.5 nz = (n == 0).nonzero()[0] A[nz, nz] = 1 A[nz, np.minimum(nz + 1, N - 1)] = -1 w[nz] = 1 k[nz] = 0 W = np.diag(w) # Small perturbations can already make the infinite loop go away (just # uncomment the next line) # k = k + 1e-10 * np.random.normal(size=N) dact, _ = nnls(W @ A, W @ k) assert_allclose(dact, d, rtol=0., atol=1e-10) def test_nnls_inner_loop_case2(self): # See gh-20168 n = np.array( [1, 0, 1, 2, 2, 2, 3, 3, 5, 4, 14, 14, 19, 26, 36, 42, 36, 64, 64, 64, 81, 85, 85, 95, 95, 95, 75, 76, 69, 81, 62, 59, 68, 64, 71, 67, 74, 78, 118, 135, 153, 159, 210, 195, 218, 243, 236, 215, 196, 175, 185, 149, 144, 103, 104, 75, 56, 40, 32, 26, 17, 9, 12, 8, 2, 1, 1, 1]) k = np.array( [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.7064355064917867, 0., 0., 2.11930651947536, 0.7064355064917867, 0., 3.5321775324589333, 7.064355064917867, 11.302968103868587, 16.95445215580288, 20.486629688261814, 20.486629688261814, 37.44108184406469, 55.808405012851146, 78.41434122058831, 103.13958394780086, 105.965325973768, 125.74552015553803, 149.057891869767, 176.60887662294667, 197.09550631120848, 211.930651947536, 204.86629688261814, 233.8301526487814, 221.1143135319292, 195.6826352982249, 197.80194181770025, 191.4440222592742, 187.91184472681525, 144.11284332432447, 131.39700420747232, 116.5618585711448, 93.24948685691584, 89.01087381796512, 53.68909849337579, 45.211872415474346, 31.083162285638615, 24.72524272721253, 16.95445215580288, 9.890097090885014, 9.890097090885014, 2.8257420259671466, 2.8257420259671466, 1.4128710129835733, 0.7064355064917867, 1.4128710129835733]) d = np.array( [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.0021916146355674473, 0., 0., 0.011252740799789484, 0., 0., 0.037746623295934395, 0.03602328132946222, 0.09509167709829734, 0.10505765870204821, 0.01391037014274718, 0.0188296228752321, 0.20723559202324254, 0.3056220879462608, 0.13304643490426477, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.043185876949706214, 0.0037266261379722554, 0., 0., 0., 0., 0., 0.094797899357143, 0., 0., 0., 0., 0., 0., 0., 0., 0.23450935613672663, 0., 0., 0.07064355064917871]) # The following code sets up a system of equations such that # $k_i-p_i*n_i$ is minimized for $p_i$ with weights $n_i$ and # monotonicity constraints on $p_i$. This translates to a system of # equations of the form $k_i - (d_1 + ... + d_i) * n_i$ and # non-negativity constraints on the $d_i$. If $n_i$ is zero the # system is modified such that $d_i - d_{i+1}$ is then minimized. N = len(n) A = np.diag(n) @ np.tril(np.ones((N, N))) w = n ** 0.5 nz = (n == 0).nonzero()[0] A[nz, nz] = 1 A[nz, np.minimum(nz + 1, N - 1)] = -1 w[nz] = 1 k[nz] = 0 W = np.diag(w) dact, _ = nnls(W @ A, W @ k, atol=1e-7) p = np.cumsum(dact) assert np.all(dact >= 0) assert np.linalg.norm(k - n * p, ord=np.inf) < 28 assert_allclose(dact, d, rtol=0., atol=1e-10) def test_nnls_gh20302(self): # See gh-20302 A = np.array( [0.33408569134321575, 0.11136189711440525, 0.049140798007949286, 0.03712063237146841, 0.055680948557202625, 0.16642814595936478, 0.11095209730624318, 0.09791993030943345, 0.14793612974165757, 0.44380838922497273, 0.11099502671044059, 0.11099502671044059, 0.14693672599330593, 0.3329850801313218, 1.498432860590948, 0.0832374225132955, 0.11098323001772734, 0.19589481249472837, 0.5919105600945457, 3.5514633605672747, 0.06658716751427037, 0.11097861252378394, 0.24485832778293645, 0.9248217710315328, 6.936163282736496, 0.05547609388181014, 0.11095218776362029, 0.29376003042571264, 1.3314262531634435, 11.982836278470993, 0.047506113282944136, 0.11084759766020298, 0.3423969672933396, 1.8105107617833156, 19.010362998724812, 0.041507335004505576, 0.11068622667868154, 0.39074115283013344, 2.361306169145206, 28.335674029742474, 0.03682846280947718, 0.11048538842843154, 0.4387861797121048, 2.9831054875676517, 40.2719240821633, 0.03311278164362387, 0.11037593881207958, 0.4870572300443105, 3.6791979604026523, 55.187969406039784, 0.030079304092299915, 0.11029078167176636, 0.5353496017200152, 4.448394860761242, 73.3985152025605, 0.02545939709595835, 0.11032405408248619, 0.6328767609778363, 6.214921713313388, 121.19097340961108, 0.022080881724881523, 0.11040440862440762, 0.7307742886903428, 8.28033064683057, 186.30743955368786, 0.020715838214945492, 0.1104844704797093, 0.7800578384588346, 9.42800814760186, 226.27219554244465, 0.01843179728340054, 0.11059078370040323, 0.8784095015912599, 11.94380463964355, 322.48272527037585, 0.015812787653789077, 0.11068951357652354, 1.0257259848595766, 16.27135849574896, 512.5477926160922, 0.014438550529330062, 0.11069555405819713, 1.1234754801775881, 19.519316032262093, 673.4164031130423, 0.012760770585072577, 0.110593345070629, 1.2688431112524712, 24.920367089248398, 971.8943164806875, 0.011427556646114315, 0.11046638091243838, 1.413623342459821, 30.967408782453557, 1347.0822820367298, 0.010033330264470307, 0.11036663290917338, 1.6071533470570285, 40.063087746029936, 1983.122843428482, 0.008950061496507258, 0.11038409179025618, 1.802244865119193, 50.37194055362024, 2795.642700725923, 0.008071078821135658, 0.11030474388885401, 1.9956465761433504, 61.80742482572119, 3801.1566267818534, 0.007191031207777556, 0.11026247851925586, 2.238160187262168, 77.7718015155818, 5366.2543045751445, 0.00636834224248, 0.11038459886965334, 2.5328963107984297, 99.49331844784753, 7760.4788389321075, 0.005624259098118485, 0.11061042892966355, 2.879742607664547, 128.34496770138628, 11358.529641572684, 0.0050354270614989555, 0.11077939535297703, 3.2263279459292575, 160.85168205252265, 15924.316523199741, 0.0044997853165982555, 0.1109947044760903, 3.6244287189055613, 202.60233390369015, 22488.859063309606, 0.004023601950058174, 0.1113196539516095, 4.07713905729421, 255.6270320242126, 31825.565487014468, 0.0036024117873727094, 0.111674765408554, 4.582933773135057, 321.9583486728612, 44913.18963986413, 0.003201503089582304, 0.11205260813538065, 5.191786833370116, 411.79333489752383, 64857.45024636, 0.0028633044552448853, 0.11262330857296549, 5.864295861648949, 522.7223161899905, 92521.84996562831, 0.0025691897303891965, 0.11304434813712465, 6.584584405106342, 656.5615739804199, 129999.19164812315, 0.0022992911894424675, 0.11343169867916175, 7.4080129906658305, 828.2026426227864, 183860.98666225857, 0.0020449922071108764, 0.11383789952917212, 8.388975556433872, 1058.2750599896935, 265097.9025274183, 0.001831274615120854, 0.11414945100919989, 9.419351803810935, 1330.564050780237, 373223.2162438565, 0.0016363333454631633, 0.11454333418242145, 10.6143816579462, 1683.787012481595, 530392.9089317025, 0.0014598610433380044, 0.11484240207592301, 11.959688127956882, 2132.0874753402027, 754758.9662704318, 0.0012985240015312626, 0.11513579480243862, 13.514425358573531, 2715.5160990137824, 1083490.9235064993, 0.0011614735761289934, 0.11537304189548002, 15.171418602667567, 3415.195870828736, 1526592.554260445, 0.0010347472698811352, 0.11554677847006009, 17.080800985009617, 4322.412404600832, 2172012.2333119176, 0.0009232988811258664, 0.1157201264344419, 19.20004861829407, 5453.349531598553, 3075689.135821584, 0.0008228871862975205, 0.11602709326795038, 21.65735242414206, 6920.203923780365, 4390869.389638642, 0.00073528900066722, 0.11642075843897651, 24.40223571298994, 8755.811207598026, 6238515.485413593, 0.0006602764384729194, 0.11752920604817965, 27.694443541914293, 11171.386093291572, 8948280.260726549, 0.0005935538977939806, 0.11851292825953147, 31.325508920763063, 14174.185724149384, 12735505.873148222, 0.0005310755355633124, 0.11913794514470308, 35.381052949627765, 17987.010118815077, 18157886.71494382, 0.00047239949671590953, 0.1190446731724092, 39.71342528048061, 22679.438775422022, 25718483.571328573, 0.00041829129789387623, 0.11851586773659825, 44.45299332965028, 28542.57147989741, 36391778.63686921, 0.00037321512015419886, 0.11880681324908665, 50.0668539579632, 36118.26128449941, 51739409.29004541, 0.0003315539616702064, 0.1184752823034871, 56.04387059062639, 45383.29960621684, 72976345.76679668, 0.00029456064937920213, 0.11831519416731286, 62.91195073220101, 57265.53993693082, 103507463.43600245, 0.00026301867496859703, 0.11862142241083726, 70.8217262087034, 72383.14781936012, 146901598.49939138, 0.00023618734450420032, 0.11966825454879482, 80.26535457124461, 92160.51176984518, 210125966.835247, 0.00021165918071578316, 0.12043407382728061, 90.7169587544247, 116975.56852918258, 299515943.218972, 0.00018757727511329545, 0.11992440455576689, 101.49899864101785, 147056.26174166967, 423080865.0307836, 0.00016654469159895833, 0.11957908856805206, 113.65970431102812, 184937.67016486943, 597533612.3026931, 0.00014717439179415048, 0.11872067604728138, 126.77899683346702, 231758.58906776624, 841283678.3159915, 0.00012868496382376066, 0.1166314722122684, 139.93635237349534, 287417.30847929465, 1172231492.6328032, 0.00011225559452625302, 0.11427619522772557, 154.0034283704458, 355281.4912295324, 1627544511.322488, 9.879511142981067e-05, 0.11295574406808354, 170.96532050841535, 442971.0111288653, 2279085852.2580123, 8.71257780313587e-05, 0.11192758284428547, 190.35067416684697, 554165.2523674504, 3203629323.93623, 7.665069027765277e-05, 0.11060694607065294, 211.28835951100046, 690933.608546013, 4486577387.093535, 6.734021094824451e-05, 0.10915848194710433, 234.24338803525194, 860487.9079859136, 6276829044.8032465, 5.9191625040287665e-05, 0.10776821865668373, 259.7454711820425, 1071699.0387579766, 8780430224.544102, 5.1856803674907676e-05, 0.10606444911641115, 287.1843540288165, 1331126.3723998806, 12251687131.5685, 4.503421404759231e-05, 0.10347361247668461, 314.7338642485931, 1638796.0697522392, 16944331963.203278, 3.90470387455642e-05, 0.1007804070023012, 344.3427560918527, 2014064.4865519698, 23392351979.057854, 3.46557661636393e-05, 0.10046706610839032, 385.56603915081587, 2533036.2523656, 33044724430.235435, 3.148745865254635e-05, 0.1025441570117926, 442.09038234164746, 3262712.3882769793, 47815050050.199135, 2.9790762078715404e-05, 0.1089845379379672, 527.8068231298969, 4375751.903321453, 72035815708.42941, 2.8772639817606534e-05, 0.11823636789048445, 643.2048194503195, 5989838.001888927, 110764084330.93005, 2.7951691815106586e-05, 0.12903432664913705, 788.5500418523591, 8249371.000613411, 171368308481.2427, 2.6844392423114212e-05, 0.1392060709754626, 955.6296403631383, 11230229.319931043, 262063016295.25085, 2.499458273851386e-05, 0.14559344445184325, 1122.7022399726002, 14820229.698461473, 388475270970.9214, 2.337386729019776e-05, 0.15294300496886065, 1324.8158105672455, 19644861.137128454, 578442936182.7473, 2.0081014872174113e-05, 0.14760215298210377, 1436.2385042492353, 23923681.729276657, 791311658718.4193, 1.773374462991839e-05, 0.14642752940923615, 1600.5596278736678, 29949429.82503553, 1112815989293.9326, 1.5303115839590797e-05, 0.14194150045081785, 1742.873058605698, 36634451.931305364, 1529085389160.7544, 1.3148448731163076e-05, 0.13699368732998807, 1889.5284359054356, 44614279.74469635, 2091762812969.9607, 1.1739194407590062e-05, 0.13739553134643406, 2128.794599579694, 56462810.11822766, 2973783283306.8145, 1.0293367506254706e-05, 0.13533033372723272, 2355.372854690074, 70176508.28667311, 4151852759764.441, 9.678312586863569e-06, 0.14293577249119244, 2794.531827932675, 93528671.31952812, 6215821967224.52, -1.174086323572049e-05, 0.1429501325944908, 3139.4804810720925, 118031680.16618933, -6466892421886.174, -2.1188265307407812e-05, 0.1477108290912869, 3644.1133424610953, 153900132.62392554, -4828013117542.036, -8.614483025123122e-05, 0.16037100755883044, 4444.386620899393, 210846007.89660168, -1766340937974.433, 4.981445776141726e-05, 0.16053420251962536, 4997.558254401547, 266327328.4755411, 3862250287024.725, 1.8500019169456637e-05, 0.15448417164977674, 5402.289867444643, 323399508.1475582, 12152445411933.408, -5.647882376069748e-05, 0.1406372975946189, 5524.633133597753, 371512945.9909363, -4162951345292.1514, 2.8048523486337994e-05, 0.13183417571186926, 5817.462495763679, 439447252.3728975, 9294740538175.03]).reshape(89, 5) b = np.ones(89, dtype=np.float64) sol, rnorm = nnls(A, b) assert_allclose(sol, np.array([0.61124315, 8.22262829, 0., 0., 0.])) assert_allclose(rnorm, 1.0556460808977297)