Skip to content

Commit d562486

Browse files
committed
uses LIBSVM instead of LIBLINEAR
1 parent b433dbb commit d562486

5 files changed

Lines changed: 66 additions & 58 deletions

File tree

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ version = "0.6.0"
66
[deps]
77
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
88
InvertedFiles = "b20bd276-2c70-11ec-161a-3d1e1109a1c3"
9-
LIBLINEAR = "2d691ee1-e668-5016-a719-b2531b85e0f5"
9+
LIBSVM = "b1bec4e5-fd48-53fe-b0cb-9723c09d164b"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1111
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
1212
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
@@ -20,7 +20,7 @@ TextSearch = "7f6f6c8a-3b03-11e9-223d-e7d88259bd6c"
2020

2121
[compat]
2222
InvertedFiles = "0.8"
23-
LIBLINEAR = "0.7"
23+
LIBSVM = "0.8"
2424
MLUtils = "0.4"
2525
Parameters = "0.12"
2626
SearchModels = "0.4"

src/TextClassification.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ using Parameters, InvertedFiles
99
include("scores.jl")
1010
include("textconfigspace.jl")
1111
include("textmodelspace.jl")
12-
include("liblinearconfig.jl")
12+
include("libsvmconfig.jl")
1313
include("microtcconfig.jl")
1414
include("microtc.jl")
1515
include("utils.jl")

src/liblinearconfig.jl

Lines changed: 0 additions & 53 deletions
This file was deleted.

src/libsvmconfig.jl

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# This file is part of TextClassification.jl
2+
3+
export LIBSVMConfig, LIBSVMConfigSpace
4+
5+
using LIBSVM
6+
7+
@with_kw struct LIBSVMConfig
8+
C::Float64 = 1.0
9+
weights = :balance
10+
end
11+
12+
struct LIBSVMWrapper{LIBSVMModel}
13+
dim::Int
14+
cls::LIBSVMModel
15+
end
16+
17+
function balanced_weights(y)
18+
C = countmap(y)
19+
s = sum(values(C))
20+
nc = length(C)
21+
Dict{Any,Float64}(label => (s / (nc * count)) for (label, count) in C)
22+
end
23+
24+
function create(config::LIBSVMConfig, train_X, train_y, dim)
25+
train_X_ = sparse(train_X, dim)
26+
nt = Threads.nthreads()
27+
verbose = true
28+
kernel = Kernel.Linear
29+
weights = balanced_weights(train_y)
30+
cls = svmtrain(train_X_, train_y; nt, weights, verbose, kernel, cost=config.C, )
31+
LIBSVMWrapper(dim, cls)
32+
end
33+
34+
@with_kw struct LIBSVMConfigSpace <: AbstractSolutionSpace
35+
C = [1.0]
36+
eps = [0.1]
37+
weights = [:balance, nothing]
38+
scale_C = (lower=0.001, s=3.0, upper=1000.0)
39+
scale_eps = (lower=0.0001, s=3.0, upper=0.99)
40+
end
41+
42+
Base.eltype(::LIBSVMConfigSpace) = LIBSVMConfig
43+
44+
function Base.rand(space::LIBSVMConfigSpace)
45+
LIBSVMConfig(rand(space.C), rand(space.weights))
46+
end
47+
48+
function combine(a::LIBSVMConfig, b::LIBSVMConfig)
49+
LIBSVMConfig(a.C, rand([a.weights, b.weights]))
50+
end
51+
52+
function mutate(space::LIBSVMConfigSpace, a::LIBSVMConfig, iter)
53+
C = space.scale_C === nothing ? a.C : SearchModels.scale(a.C; space.scale_C...)
54+
weights = rand([a.weights, rand(space.weights)])
55+
LIBSVMConfig(C, weights)
56+
end
57+
58+
function predict(w::LIBSVMWrapper, vec::SVEC)
59+
ypred = svmpredict(w.cls, sparse([vec], w.dim); nt=Threads.nthreads())
60+
ypred[1][1]
61+
end

src/microtcconfig.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ end
1010
function MicroTC_Config(;
1111
textconfig=TextConfig(),
1212
textmodel=EntModelConfig(),
13-
cls=LiblinearConfig(1.0, 0.1)
13+
cls=LIBSVMConfig(1.0, 0.1)
1414
)
1515

1616
MicroTC_Config(textconfig, textmodel, cls)
@@ -38,7 +38,7 @@ Base.isequal(a::MicroTC_Config, b::MicroTC_Config) = repr(a) == repr(b)
3838

3939
function MicroTC_ConfigSpace(;
4040
textmodel=[VectorModelConfigSpace(), EntModelConfigSpace()],
41-
cls=[LiblinearConfigSpace()],
41+
cls=[LIBSVMConfigSpace()],
4242
textconfig::TextConfigSpace = TextConfigSpace()
4343
)
4444

0 commit comments

Comments
 (0)