Skip to content

Commit bbdf078

Browse files
authored
[Nonlinear] add support for atan2 (#1987)
1 parent 9ba623e commit bbdf078

File tree

5 files changed

+93
-3
lines changed

5 files changed

+93
-3
lines changed

src/FileFormats/NL/NLExpr.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ const _JULIA_TO_AMPL = Dict{Symbol,Int}(
6262
:cosh => OP_cosh,
6363
:cos => OP_cos,
6464
:atanh => OP_atanh,
65-
# OP_atan2 = 48,
65+
# Because :atan is also a univariate operator, this is special cased in
66+
# process_expr!.
67+
# :atan = OP_atan2,
6668
:atan => OP_atan,
6769
:asinh => OP_asinh,
6870
:asin => OP_asin,
@@ -146,7 +148,7 @@ const _AMPL_TO_JULIA = Dict{Int,Tuple{Int,Function}}(
146148
OP_cosh => (1, cosh),
147149
OP_cos => (1, cos),
148150
OP_atanh => (1, atanh),
149-
# OP_atan2 = 48,
151+
OP_atan2 => (2, atan),
150152
OP_atan => (1, atan),
151153
OP_asinh => (1, asinh),
152154
OP_asin => (1, asin),
@@ -431,6 +433,11 @@ function _process_expr!(expr::_NLExpr, args::Vector{Any})
431433
if opcode === nothing
432434
error("Unsupported operation $(op)")
433435
end
436+
if op == :atan && N == 2
437+
# Special case binary use of atan, because Julia uses method overloading
438+
# instead of having an explicit atan2 function.
439+
opcode = OP_atan2
440+
end
434441
push!(expr.nonlinear_terms, opcode)
435442
if opcode in _NARY_OPCODES
436443
push!(expr.nonlinear_terms, N)

src/Nonlinear/operators.jl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ const DEFAULT_UNIVARIATE_OPERATORS = first.(SYMBOLIC_UNIVARIATE_EXPRESSIONS)
8282
8383
The list of multivariate operators that are supported by default.
8484
"""
85-
const DEFAULT_MULTIVARIATE_OPERATORS = [:+, :-, :*, :^, :/, :ifelse]
85+
const DEFAULT_MULTIVARIATE_OPERATORS = [:+, :-, :*, :^, :/, :ifelse, :atan]
8686

8787
"""
8888
OperatorRegistry()
@@ -544,6 +544,9 @@ function eval_multivariate_function(
544544
elseif op == :ifelse
545545
@assert length(x) == 3
546546
return ifelse(Bool(x[1]), x[2], x[3])
547+
elseif op == :atan
548+
@assert length(x) == 2
549+
return atan(x[1], x[2])
547550
end
548551
id = registry.multivariate_operator_to_id[op]
549552
offset = id - registry.multivariate_user_operator_start
@@ -619,6 +622,11 @@ function eval_multivariate_gradient(
619622
g[1] = zero(T) # It doesn't matter what this is.
620623
g[2] = x[1] == one(T)
621624
g[3] = x[1] == zero(T)
625+
elseif op == :atan
626+
@assert length(x) == 2
627+
base = x[1]^2 + x[2]^2
628+
g[1] = x[2] / base
629+
g[2] = -x[1] / base
622630
else
623631
id = registry.multivariate_operator_to_id[op]
624632
offset = id - registry.multivariate_user_operator_start
@@ -707,6 +715,18 @@ function eval_multivariate_hessian(
707715
d = 1 / x[2]^2
708716
H[2, 1] = -d
709717
H[2, 2] = 2 * x[1] * d / x[2]
718+
elseif op == :atan
719+
# f(x) = atan(y, x)
720+
#
721+
# ∇f(x) = +x/(x^2+y^2)
722+
# -y/(x^2+y^2)
723+
#
724+
# ∇²(x) = -(2xy)/(x^2+y^2)^2
725+
# (y^2-x^2)/(x^2+y^2)^2 (2xy)/(x^2+y^2)^2
726+
base = (x[1]^2 + x[2]^2)^2
727+
H[1, 1] = -2 * x[2] * x[1] / base
728+
H[2, 1] = (x[1]^2 - x[2]^2) / base
729+
H[2, 2] = 2 * x[2] * x[1] / base
710730
else
711731
id = registry.multivariate_operator_to_id[op]
712732
offset = id - registry.multivariate_user_operator_start

test/FileFormats/NL/NL.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,19 @@ function test_nlexpr_binary_specialcase()
212212
)
213213
end
214214

215+
function test_nlexpr_atan_and_atan2()
216+
x = MOI.VariableIndex(1)
217+
y = MOI.VariableIndex(2)
218+
_test_nlexpr(:(atan($x)), [NL.OP_atan, x], Dict(x => 0.0), 0.0)
219+
_test_nlexpr(
220+
:(atan($x, $y)),
221+
[NL.OP_atan2, x, y],
222+
Dict(x => 0.0, y => 0.0),
223+
0.0,
224+
)
225+
return
226+
end
227+
215228
function test_nlexpr_unsupportedoperation()
216229
x = MOI.VariableIndex(1)
217230
err = ErrorException("Unsupported operation foo")

test/FileFormats/NL/read.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,28 @@ function test_parse_header_binary()
118118
return
119119
end
120120

121+
function test_parse_expr_atan2()
122+
model = NL._CacheModel()
123+
io = IOBuffer()
124+
write(io, "o48\nv0\nv1\n")
125+
seekstart(io)
126+
x = MOI.VariableIndex.(1:2)
127+
@test NL._parse_expr(io, model) == :(atan($(x[1]), $(x[2])))
128+
@test eof(io)
129+
return
130+
end
131+
132+
function test_parse_expr_atan()
133+
model = NL._CacheModel()
134+
io = IOBuffer()
135+
write(io, "o49\nv0\n")
136+
seekstart(io)
137+
x = MOI.VariableIndex.(1:1)
138+
@test NL._parse_expr(io, model) == :(atan($(x[1])))
139+
@test eof(io)
140+
return
141+
end
142+
121143
function test_parse_header_assertion_errors()
122144
model = NL._CacheModel()
123145
for header in [

test/Nonlinear/Nonlinear.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -842,6 +842,34 @@ function test_NLPBlockData()
842842
return
843843
end
844844

845+
function test_parse_atan2()
846+
model = Nonlinear.Model()
847+
x = MOI.VariableIndex(1)
848+
y = MOI.VariableIndex(2)
849+
θ = π / 4
850+
Nonlinear.add_constraint(model, :(atan($x, $y)), MOI.LessThan(θ))
851+
evaluator = Nonlinear.Evaluator(model)
852+
MOI.initialize(evaluator, [:ExprGraph])
853+
@test MOI.constraint_expr(evaluator, 1) == :(atan(x[$x], x[$y]) <= $θ)
854+
return
855+
end
856+
857+
function test_eval_atan2()
858+
r = Nonlinear.OperatorRegistry()
859+
x = [1.1, 2.2]
860+
@test Nonlinear.eval_multivariate_function(r, :atan, x) atan(x[1], x[2])
861+
g = zeros(2)
862+
Nonlinear.eval_multivariate_gradient(r, :atan, g, x)
863+
@test g[1] x[2] / (x[1]^2 + x[2]^2)
864+
@test g[2] -x[1] / (x[1]^2 + x[2]^2)
865+
H = LinearAlgebra.LowerTriangular(zeros(2, 2))
866+
@test Nonlinear.eval_multivariate_hessian(r, :atan, H, x)
867+
@test H[1, 1] -2 * x[2] * x[1] / (x[1]^2 + x[2]^2)^2
868+
@test H[2, 1] (x[1]^2 - x[2]^2) / (x[1]^2 + x[2]^2)^2
869+
@test H[2, 2] 2 * x[2] * x[1] / (x[1]^2 + x[2]^2)^2
870+
return
871+
end
872+
845873
end
846874

847875
TestNonlinear.runtests()

0 commit comments

Comments
 (0)