Skip to content

Commit c917575

Browse files
committed
Adapt the coverage we can use now
1 parent 0ca3d16 commit c917575

1 file changed

Lines changed: 73 additions & 51 deletions

File tree

Tests/MatftTests/LinAlgTest.swift

Lines changed: 73 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,57 @@
1-
// Disabled temporally until we provide better WASM support
2-
#if !os(WASI)
31
import XCTest
42

53
@testable import Matft
64

75
final class LinAlgTests: XCTestCase {
8-
6+
7+
// MARK: - Tests requiring gesv_ (not available on WASM)
8+
#if !os(WASI)
99
func testSolve() {
1010
do{
1111
let coef = MfArray([[3,2],[1,2]])
1212
let b = MfArray([7,1])
13-
13+
1414
XCTAssertEqual(try Matft.linalg.solve(coef, b: b), MfArray([ 3.0, -1.0], mftype: .Float))
1515
}
1616

1717
do{
1818
let coef = MfArray([[3,1],[1,2]], mftype: .Double)
1919
let b = MfArray([9,8])
20-
20+
2121
XCTAssertEqual(try Matft.linalg.solve(coef, b: b), MfArray([ 2.0, 3.0], mftype: .Double))
2222
}
23-
23+
2424
do{
2525
let coef = MfArray([[1,2],[2,4]])
2626
let b = MfArray([-1,-2])
27-
27+
2828
XCTAssertThrowsError(try Matft.linalg.solve(coef, b: b))
2929
}
30-
30+
3131
do{
3232
let coef = MfArray([[1,2],[2,4]])
3333
let b = MfArray([-1,-3])
34-
34+
3535
XCTAssertThrowsError(try Matft.linalg.solve(coef, b: b))
3636
}
3737
}
38-
38+
#endif
39+
3940

4041
func testInv(){
42+
// Float test - requires sgetrf_/sgetri_ (not available on WASM)
43+
#if !os(WASI)
4144
do{
4245
let a = MfArray([[1, 2], [3, 4]])
4346
XCTAssertEqual(try Matft.linalg.inv(a), MfArray([[-2.0 , 1.0 ],
4447
[ 1.5, -0.5]], mftype: .Float))
4548
}
49+
#endif
50+
// Double test - uses dgetrf_/dgetri_ (available on WASM via CLAPACK)
4651
do{
4752
let a = MfArray([[[1.0, 2.0],
4853
[3.0, 4.0]],
49-
54+
5055
[[1.0, 3.0],
5156
[3.0, 5.0]]], mftype: .Double)
5257
XCTAssertEqual(try Matft.linalg.inv(a), MfArray([[[-2.0 , 1.0 ],
@@ -55,25 +60,30 @@ final class LinAlgTests: XCTestCase {
5560
[ 0.75, -0.25]]], mftype: .Double))
5661
}
5762
}
58-
63+
5964
func testDet(){
60-
65+
// Float test - requires sgetrf_ (not available on WASM)
66+
#if !os(WASI)
6167
do{
6268
let a = MfArray([[1, 2], [3, 4]])
6369
XCTAssertEqual(try Matft.linalg.det(a), MfArray([-2.0], mftype: .Float))
6470
}
65-
71+
#endif
72+
73+
// Double test - uses dgetrf_ (available on WASM via CLAPACK)
6674
do{
6775
let a = MfArray([[[1.0, 2.0],
6876
[3.0, 4.0]],
69-
77+
7078
[[1.0, 3.0],
7179
[3.0, 5.0]]], mftype: .Double)
7280
XCTAssertEqual(try Matft.linalg.det(a), MfArray([-2.0, -4.0], mftype: .Double))
7381
}
7482
}
75-
83+
7684
func testEigen(){
85+
// Float test - requires sgeev_ (not available on WASM)
86+
#if !os(WASI)
7787
do{
7888
let a = MfArray([[1, -1], [1, 1]])
7989
let ret = try! Matft.linalg.eigen(a)
@@ -90,8 +100,10 @@ final class LinAlgTests: XCTestCase {
90100
[0.0, 0.0]], mftype: .Float))
91101
XCTAssertEqual(ret.rvecIm, MfArray([[0.0, 0.0],
92102
[-0.70710677, 0.70710677]], mftype: .Float))
93-
103+
94104
}
105+
#endif
106+
// Double tests - use dgeev_ (available on WASM via pure Swift fallback)
95107
do{
96108
let a = MfArray([[[1,0,0],
97109
[0,2,0],
@@ -114,7 +126,7 @@ final class LinAlgTests: XCTestCase {
114126
XCTAssertEqual(ret.rvecIm, MfArray([[[0.0, 0.0, 0.0],
115127
[0.0, 0.0, 0.0],
116128
[0.0, 0.0, 0.0]]], mftype: .Double))
117-
129+
118130
}
119131
do{
120132
let a = MfArray([[0, -1],
@@ -133,27 +145,29 @@ final class LinAlgTests: XCTestCase {
133145
[0.0, 0.0]], mftype: .Double))
134146
XCTAssertEqual(ret.rvecIm, MfArray([[0.0, 0.0],
135147
[-0.707106781186547, 0.707106781186547]], mftype: .Double))
136-
148+
137149
}
138150
}
139-
151+
152+
// MARK: - Tests requiring SVD (not available on WASM)
153+
#if !os(WASI)
140154
func testSVD(){
141155
do{
142156
let a = MfArray([[2, 4, 1, 3],
143157
[1, 5, 3, 2],
144158
[5, 7, 0, 7]])
145159
let ret = try! Matft.linalg.svd(a)
146-
160+
147161
// v and rt is not unique?
148162
XCTAssertEqual(ret.v *& ret.v.T, Matft.eye(dim: 3, mftype: .Float))
149163
//astype is for avoiding minute error
150164
XCTAssertEqual(ret.s.astype(.Float), MfArray([1.33853840e+01, 3.58210781e+00, 5.07054122e-16], mftype: .Float))
151165
XCTAssertEqual(ret.rt *& ret.rt.T , Matft.eye(dim: 4, mftype: .Float))
152-
166+
153167
let ret_nofull = try! Matft.linalg.svd(a, full_matrices: false)
154168
XCTAssertEqual((ret_nofull.v *& Matft.diag(v: ret_nofull.s) *& ret_nofull.rt).nearest(), a)
155169
}
156-
170+
157171
do{
158172
let a = MfArray([[1, 2],
159173
[3, 4]])
@@ -167,12 +181,12 @@ final class LinAlgTests: XCTestCase {
167181
[ 0.81741556, -0.57604844]], mftype: .Float))
168182
XCTAssertEqual(ret.v *& ret.v.T, Matft.eye(dim: 2, mftype: .Float))
169183
XCTAssertEqual(ret.rt *& ret.rt.T , Matft.eye(dim: 2, mftype: .Float))
170-
184+
171185
XCTAssertEqual((ret.v *& Matft.diag(v: ret.s) *& ret.rt).nearest(), a)
172186
}
173-
187+
174188
}
175-
189+
176190
func testPInv(){
177191
do{
178192
let a = MfArray([[2, -1, 0],
@@ -182,7 +196,7 @@ final class LinAlgTests: XCTestCase {
182196
[-0.36666667, 0.16666667],
183197
[ 0.08333333, -0.08333333]], mftype: .Float).round(decimals: 5))
184198
}
185-
199+
186200
do{
187201
let a = MfArray([[ 0.10122714, -1.7555435 , 0.72242671],
188202
[ 0.70605646, -3.03520525, -0.8974524 ],
@@ -193,7 +207,7 @@ final class LinAlgTests: XCTestCase {
193207
[-0.24171501, -0.14397516, 0.0288316 , -0.16416708],
194208
[ 0.40742503, -0.2408292 , 0.30600237, 0.23674046]], mftype: .Float).round(decimals: 3))
195209
}
196-
210+
197211
do{
198212
let a = MfArray([[-33, 43, 25],
199213
[-65, -36, -33],
@@ -204,7 +218,7 @@ final class LinAlgTests: XCTestCase {
204218
[ 0.00937469, -0.00758197, 0.00732988, -0.00020324],
205219
[ 0.00565919, -0.00350739, -0.00734483, 0.00663407]], mftype: .Float).round(decimals: 7))
206220
}
207-
221+
208222
do{
209223
let a = MfArray([[7, 2],
210224
[3, 4],
@@ -213,7 +227,7 @@ final class LinAlgTests: XCTestCase {
213227
XCTAssertEqual(ret.round(decimals: 7), MfArray([[ 0.16666667, -0.10606061, 0.03030303],
214228
[-0.16666667, 0.28787879, 0.06060606]], mftype: .Double).round(decimals: 7))
215229
}
216-
230+
217231
do{
218232
let a = MfArray([[ -6, 4, -1, 8, 2],
219233
[ -1, 6, -10, -1, 6],
@@ -226,7 +240,7 @@ final class LinAlgTests: XCTestCase {
226240
[ 0.07609496, 0.03942475, 0.10520211]], mftype: .Float))
227241
}
228242
}
229-
243+
230244
func testPolar(){
231245
do{
232246
let a = MfArray([[1, -1],
@@ -236,13 +250,13 @@ final class LinAlgTests: XCTestCase {
236250
[ 0.51449576, 0.85749293]], mftype: .Float))
237251
XCTAssertEqual(retR.p, MfArray([[ 1.88648444, 1.2004901 ],
238252
[ 1.2004901 , 3.94446746]], mftype: .Float))
239-
253+
240254
let retL = try! Matft.linalg.polar_left(a)
241255
XCTAssertEqual(retL.l, MfArray([[ 0.85749293, -0.51449576],
242256
[ 0.51449576, 0.85749293]], mftype: .Float))
243257
XCTAssertEqual(retL.p, MfArray([[ 1.37198868, -0.34299717],
244258
[-0.34299717, 4.45896321]], mftype: .Float))
245-
259+
246260
}
247261
do{
248262
let a = MfArray([[0.5, 1, 2],
@@ -262,14 +276,16 @@ final class LinAlgTests: XCTestCase {
262276
XCTAssertEqual(retL.p.astype(.Float), MfArray([[1.02156625, 1.98464238, 0.51729779],
263277
[1.98464238, 4.35624892, 2.08189576],
264278
[0.51729779, 2.08189576, 3.55641857]], mftype: .Float))
265-
279+
266280
}
267281
}
268-
282+
#endif
283+
284+
// MARK: - Norm tests (pure Swift, work on WASM)
269285
func testNorm_vec(){
270286
do{
271287
let a = Matft.arange(start: 0, to: 16, by: 1, shape: [2,2,2,2])
272-
288+
273289
XCTAssertEqual(Matft.linalg.normlp_vec(a, ord: 2, axis: 3).round(decimals: 4),
274290
MfArray([[[ 1.0 , 3.60555128],
275291
[ 6.40312424, 9.21954446]],
@@ -282,72 +298,78 @@ final class LinAlgTests: XCTestCase {
282298

283299
[[12.64911064, 13.92838828],
284300
[15.23154621, 16.55294536]]], mftype: .Float).round(decimals: 4))
285-
301+
286302
XCTAssertEqual(Matft.linalg.normlp_vec(a, ord: 0, axis: 0).round(decimals: 4),
287303
MfArray([[[1.0, 2.0],
288304
[2.0, 2.0]],
289305

290306
[[2.0, 2.0],
291307
[2.0, 2.0]]], mftype: .Float).round(decimals: 4))
292-
308+
293309
XCTAssertEqual(Matft.linalg.normlp_vec(a, ord: Float.infinity, axis: -1).round(decimals: 4),
294310
MfArray([[[ 1.0, 3.0],
295311
[ 5.0, 7.0]],
296312

297313
[[ 9.0, 11.0],
298314
[13.0, 15.0]]], mftype: .Float).round(decimals: 4))
299315
}
300-
316+
301317
}
302-
318+
303319
func testNormlp_mat(){
304320
do{
305321
let a = Matft.arange(start: 0, to: 16, by: 1, shape: [2,2,2,2])
322+
// ord=2 requires SVD - only test on non-WASM platforms
323+
#if !os(WASI)
306324
XCTAssertEqual(Matft.linalg.normlp_mat(a, ord: 2, axes: (3, 1)).round(decimals: 4),
307325
MfArray([[ 6.45100985, 9.89123156],
308326
[21.40011829, 25.3372271 ]], mftype: .Float).round(decimals: 4))
309327
XCTAssertEqual(Matft.linalg.normlp_mat(a, ord: 2, axes: (0, -1)).round(decimals: 4),
310328
MfArray([[12.06483816, 15.28810568],
311329
[18.81008019, 22.49163147]], mftype: .Float).round(decimals: 4))
312-
330+
#endif
331+
332+
// ord=-1 and ord=inf use pure Swift operations - work on WASM
313333
XCTAssertEqual(Matft.linalg.normlp_mat(a, ord: -1, axes: (2, 3)).round(decimals: 4),
314334
MfArray([[ 2.0, 10.0],
315335
[18.0, 26.0]], mftype: .Float).round(decimals: 4))
316-
336+
317337
XCTAssertEqual(Matft.linalg.normlp_mat(a, ord: Float.infinity, axes: (-1, 0)).round(decimals: 4),
318338
MfArray([[10.0, 14.0],
319339
[18.0, 22.0]], mftype: .Float).round(decimals: 4))
320340
}
321-
341+
322342
}
323-
343+
324344
func testNormFro_mat(){
325-
345+
326346
do{
327347
let a = Matft.arange(start: 0, to: 16, by: 1, shape: [2,2,2,2])
328-
348+
329349
XCTAssertEqual(Matft.linalg.normfro_mat(a, axes: (2, 0), keepDims: false).round(decimals: 4),
330350
MfArray([[12.9614814 , 14.56021978],
331351
[19.79898987, 21.63330765]], mftype: .Float).round(decimals: 4))
332352
XCTAssertEqual(Matft.linalg.normfro_mat(a, axes: (-2, 1), keepDims: false).round(decimals: 4),
333353
MfArray([[ 7.48331477, 9.16515139],
334354
[22.44994432, 24.41311123]], mftype: .Float).round(decimals: 4))
335-
355+
336356
}
337357
}
338-
358+
359+
// Nuclear norm requires SVD - not available on WASM
360+
#if !os(WASI)
339361
func testNormNuc_mat(){
340362
do{
341363
let a = Matft.arange(start: 0, to: 16, by: 1, shape: [2,2,2,2])
342-
364+
343365
XCTAssertEqual(Matft.linalg.normnuc_mat(a, axes: (2, 0), keepDims: false).round(decimals: 4),
344366
MfArray([[14.14213562, 15.62049935],
345367
[20.59126028, 22.36067977]], mftype: .Float).round(decimals: 4))
346368
XCTAssertEqual(Matft.linalg.normnuc_mat(a, axes: (-2, 1), keepDims: false).round(decimals: 4),
347369
MfArray([[ 8.48528137, 10.0 ],
348370
[22.8035085 , 24.73863375]], mftype: .Float).round(decimals: 4))
349-
371+
350372
}
351373
}
374+
#endif
352375
}
353-
#endif

0 commit comments

Comments
 (0)