Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SourceCodeMcCormick"
uuid = "a7283dc5-4ecf-47fb-a95b-1412723fc960"
authors = ["Robert Gottlieb <Robert.x.gottlieb@uconn.edu>"]
version = "0.5.0"
version = "0.5.1"

[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Expand Down
101 changes: 86 additions & 15 deletions src/kernel_writer/kernel_write.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ kgen(num::Num, raw_outputs::Vector{Symbol}; constants::Vector{Num}=Num[], overwr
kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}; constants::Vector{Num}=Num[], overwrite::Bool=false, splitting::Symbol=:default, affine_quadratic::Bool=true) = kgen(num, gradlist, raw_outputs, constants, overwrite, splitting, affine_quadratic)
function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, constants::Vector{Num}, overwrite::Bool, splitting::Symbol, affine_quadratic::Bool)
# Create a hash of the expression and check if the function already exists
expr_hash = string(hash(num+sum(gradlist)), base=62)
expr_hash = string(hash(string(num)*string(gradlist)), base=62)
if (overwrite==false) && (isfile(joinpath(@__DIR__, "storage", "f_"*expr_hash*".jl")))
try func_name = eval(Meta.parse("f_"*expr_hash))
return func_name
Expand Down Expand Up @@ -102,9 +102,6 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons
elseif splitting==:high # Formerly default
split_point = 1500
max_size = 2000
# elseif splitting==:high # More splitting
# split_point = 1000
# max_size = 1200
elseif splitting==:max # Extremely small
split_point = 500
max_size = 750
Expand All @@ -116,7 +113,7 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons
sparsity = detect_sparsity(factored, gradlist)

# Decide if the kernel needs to be split
if (n_vars[end] < 31) && (n_lines[end] <= max_size)
if (n_vars[end] < 31) && ((n_lines[end] <= max_size) || (findfirst(x -> x > split_point, n_lines)==length(n_lines)))
# Complexity is fairly low; only a single kernel needed
create_kernel!(expr_hash, 1, num, get_name.(gradlist), func_outputs, constants, factored, sparsity)
push!(kernel_nums, 1)
Expand All @@ -130,7 +127,7 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons
while !complete
# Determine which line to break at
line_ID = findfirst(x -> x > split_point, n_lines)
vars_ID = findfirst(x -> x == 31, n_vars)
vars_ID = findfirst(x -> (x == 30) || (x == 31), n_vars)
if isnothing(vars_ID)
new_ID = line_ID
elseif isnothing(line_ID)
Expand Down Expand Up @@ -188,7 +185,7 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons
n_lines = complexity(factored)
n_vars = var_counts(factored)

# If the total number of lines (not including the final line) is below 2000
# If the total number of lines (not including the final line) is below the max size
# and the number of variables is below 32, we can make the final kernel and be done
if (n_vars[end] < 32) && (all(n_lines[1:end-1] .<= max_size))
create_kernel!(expr_hash, kernel_count, extract(factored), get_name.(gradlist), func_outputs, constants, factored, sparsity)
Expand Down Expand Up @@ -328,7 +325,12 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num
file = open(joinpath(@__DIR__, "storage", "f_"*expr_hash*".jl"), "a")

# Put in the preamble.
write(file, preamble_string(expr_hash, ["OUT"; string.(vars)], 1, 1, length(gradlist)))
if isempty(vars)
write(file, preamble_string(expr_hash, ["OUT";], 1, 1, length(gradlist)))
else
write(file, preamble_string(expr_hash, ["OUT"; string.(vars)], 1, 1, length(gradlist)))
end


# Depending on the format of the expression, compose the kernel differently
if typeof(expr) <: Real
Expand Down Expand Up @@ -360,9 +362,9 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num
end
end
else # There must be two elements in the dictionary
binary_vars = string.(get_name.(keys(key.dict)))
binary_vars = string.(get_name.(keys(expr.dict)))
binary_vars = binary_vars[sort_vars(binary_vars)]
write(file, SCMC_quadaff_binary(vars..., expr.coeff, varlist))
write(file, SCMC_quadaff_binary(binary_vars..., expr.coeff, varlist))
end

elseif exprtype(expr)==ADD
Expand Down Expand Up @@ -394,7 +396,13 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num
# EAGO already does this and bypasses the need to calculate relaxations.
# But, for compatibility with McCormick-style relaxations in ParBB,
# it's easier to simply calculate what ParBB is expecting.)
write(file, postamble_quadaff(string.(vars), varlist))
if isempty(varlist)
write(file, postamble_quadaff(String[], String[]))
elseif isempty(vars)
write(file, postamble_quadaff(String[], varlist))
else
write(file, postamble_quadaff(string.(vars), varlist))
end
close(file)

# Include this kernel so SCMC knows what it is
Expand All @@ -403,7 +411,13 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num
# Add onto the file the "main" CPU function that calls the kernel
blocks = Int32(CUDA.attribute(CUDA.device(), CUDA.DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT))
file = open(joinpath(@__DIR__, "storage", "f_"*expr_hash*".jl"), "a")
write(file, outro(expr_hash, [1], [string.(vars)], ["OUT"], blocks, get_name.(gradlist)))
if isempty(gradlist)
write(file, outro(expr_hash, [1], [String[]], ["OUT"], blocks, Symbol[]))
elseif isempty(vars)
write(file, outro(expr_hash, [1], [String[]], ["OUT"], blocks, get_name.(gradlist)))
else
write(file, outro(expr_hash, [1], [string.(vars)], ["OUT"], blocks, get_name.(gradlist)))
end
close(file)

# Include the file again to get the final kernel
Expand Down Expand Up @@ -731,6 +745,7 @@ end
# 7) log(inv(x1)) = -log(x1) [EAGO paper]
# 8) CONST1*CONST2*x1 = (CONST1*CONST2)*x1
# 9) 1 / (1 + exp(-x)) = Sigmoid(x)
# 10) sin(x) = cos(x - pi/2)
#
# Forms that aren't relevant yet:
# 1) (a^x1)^b = (a^b)^x1 [EAGO paper] (Can't do powers besides integers)
Expand Down Expand Up @@ -826,7 +841,7 @@ function perform_substitutions(old_factored::Vector{Equation})
end
end
# Create a factorization of this new expr
new_factorization = factor(new_expr)
new_factorization = factor(new_expr, split_div=true)
# Scan through the new factorization to see if we can merge elements
# with the original factored list
done = false
Expand Down Expand Up @@ -1191,7 +1206,7 @@ function perform_substitutions(old_factored::Vector{Equation})
new_expr *= arg
end
# Create a factorization of this new expr
new_factorization = factor(new_expr)
new_factorization = factor(new_expr, split_div=true)


# Scan through the new factorization to see if we can merge elements
Expand Down Expand Up @@ -1315,6 +1330,38 @@ function perform_substitutions(old_factored::Vector{Equation})
end
end
end

# 10) sin(x) = cos(x - pi/2)
if exprtype(factored[index0].rhs)==TERM
if factored[index0].rhs.f==sin
# We found sin(arg). Check if (arg - pi/2) exists,
# and if so, also check if cos(arg - pi/2) exists.
scan_flag = true
index1 = findfirst(x -> isequal(x.rhs, arguments(factored[index0].rhs)[] - pi/2), factored)
if !isnothing(index1)
index2 = findfirst(x -> isequal(x.rhs, cos(factored[index1].lhs)), factored)
if !isnothing(index2)
# cos(arg - pi/2) exists already (index2). Remove all reference to index0 and replace with index2
for i in eachindex(factored)
@eval $factored[$i] = $factored[$i].lhs ~ substitute($factored[$i].rhs, Dict($factored[$index0].lhs => $factored[$index2].lhs))
end
deleteat!(factored, index0)
else
# arg - pi/2 exists already (index1), but not cos(arg - pi/2). Change
# index0 to be cos of index1.lhs instead of sin of arg
@eval $factored[$index0] = $factored[$index0].lhs ~ cos($factored[$index1].lhs)
end
else
# (arg - pi/2) doesn't exist, so we need to create it
newsym = gensym(:aux)
newsym = Symbol(string(newsym)[3:5] * string(newsym)[7:end])
newvar = genvar(newsym)
insert!(factored, index0, Equation(Symbolics.value(newvar), arguments(factored[index0].rhs)[] - pi/2))
@eval $factored[$index0+1] = $factored[$index0+1].lhs ~ cos($newvar)
end
break
end
end
end
end

Expand Down Expand Up @@ -1511,6 +1558,10 @@ function write_operation(file::IOStream, RHS::BasicSymbolic{Real}, inputs::Vecto
write(file, SCMC_sigmoid_kernel(inputs..., gradlist, sparsity))
elseif RHS.f==sqrt
write(file, SCMC_float_power_kernel(inputs..., 0.5, gradlist, sparsity))
elseif RHS.f==cos
write(file, SCMC_cos_kernel(inputs..., gradlist, sparsity))
elseif RHS.f==abs
write(file, SCMC_abs_kernel(inputs..., gradlist, sparsity))
else
close(file)
error("Some function was used that we can't handle yet ($RHS)")
Expand Down Expand Up @@ -1845,6 +1896,10 @@ function _complexity(complexity::Vector{Int}, factorized::Vector{Equation}, star
else
total_lines += 190
end
new_ID = findfirst(x -> isequal(x.lhs, RHS.base), factorized)
if !isnothing(new_ID)
total_lines += _complexity(complexity, factorized, new_ID)
end
elseif exprtype(RHS) == TERM
if RHS.f==exp
total_lines += 212 # Ranges from 212--310
Expand All @@ -1866,8 +1921,24 @@ function _complexity(complexity::Vector{Int}, factorized::Vector{Equation}, star
end
elseif RHS.f==sqrt
total_lines += 190
new_ID = findfirst(x -> isequal(x.lhs, RHS.arguments[1]), factorized)
if !isnothing(new_ID)
total_lines += _complexity(complexity, factorized, new_ID)
end
elseif RHS.f==cos || RHS.f==sin
total_lines += 300
new_ID = findfirst(x -> isequal(x.lhs, RHS.arguments[1]), factorized)
if !isnothing(new_ID)
total_lines += _complexity(complexity, factorized, new_ID)
end
elseif RHS.f==abs
total_lines += 280
new_ID = findfirst(x -> isequal(x.lhs, RHS.arguments[1]), factorized)
if !isnothing(new_ID)
total_lines += _complexity(complexity, factorized, new_ID)
end
else
error("Unknown function")
error("Some function was used that we can't handle yet ($RHS)")
end
elseif exprtype(RHS) == SYM
nothing
Expand Down
Loading
Loading