Skip to content

Commit 5c6ce35

Browse files
committed
Loosen other type requirements
1 parent 7f68448 commit 5c6ce35

File tree

4 files changed

+10
-10
lines changed

4 files changed

+10
-10
lines changed

benchmark/benchmark_utils.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ function random_node(tree::Node{T})::Node{T} where {T}
2525
return random_node(tree.r)
2626
end
2727

28-
function make_random_leaf(nfeatures::Int, ::Type{T})::Node{T} where {T}
28+
function make_random_leaf(nfeatures::Integer, ::Type{T})::Node{T} where {T}
2929
if rand() > 0.5
3030
return Node(; val=randn(T))
3131
else
@@ -35,7 +35,7 @@ end
3535

3636
# Add a random unary/binary operation to the end of a tree
3737
function append_random_op(
38-
tree::Node{T}, operators, nfeatures::Int; makeNewBinOp::Union{Bool,Nothing}=nothing
38+
tree::Node{T}, operators, nfeatures::Integer; makeNewBinOp::Union{Bool,Nothing}=nothing
3939
)::Node{T} where {T}
4040
nuna = length(operators.unaops)
4141
nbin = length(operators.binops)
@@ -64,7 +64,7 @@ function append_random_op(
6464
end
6565

6666
function gen_random_tree_fixed_size(
67-
node_count::Int, operators, nfeatures::Int, ::Type{T}
67+
node_count::Integer, operators, nfeatures::Integer, ::Type{T}
6868
)::Node{T} where {T}
6969
tree = make_random_leaf(nfeatures, T)
7070
cur_size = count_nodes(tree)

ext/DynamicExpressionsSymbolicUtilsExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ function parse_tree_to_eqs(
2727
return SymbolicUtils.Sym{LiteralReal}(Symbol("x$(tree.feature)"))
2828
end
2929
# Collect the next children
30-
children = tree.degree >= 2 ? (tree.l, tree.r) : (tree.l,)
30+
children = tree.degree == 2 ? (tree.l, tree.r) : (tree.l,)
3131
# Get the operation
32-
op = tree.degree > 1 ? operators.binops[tree.op] : operators.unaops[tree.op]
32+
op = tree.degree == 2 ? operators.binops[tree.op] : operators.unaops[tree.op]
3333
# Create an N tuple of Numbers for each argument
3434
dtypes = map(x -> Number, 1:(tree.degree))
3535
#

src/EquationUtils.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,25 +92,25 @@ end
9292
# This will mirror a Node struct, rather
9393
# than adding a new attribute to Node.
9494
mutable struct NodeIndex
95-
constant_index::Int16 # Index of this constant (if a constant exists here)
95+
constant_index::UInt16 # Index of this constant (if a constant exists here)
9696
l::NodeIndex
9797
r::NodeIndex
9898

9999
NodeIndex() = new()
100100
end
101101

102102
function index_constants(tree::Node)::NodeIndex
103-
return index_constants(tree, Int16(0))
103+
return index_constants(tree, UInt16(0))
104104
end
105105

106-
function index_constants(tree::Node, left_index::Int16)::NodeIndex
106+
function index_constants(tree::Node, left_index)::NodeIndex
107107
index_tree = NodeIndex()
108108
index_constants!(tree, index_tree, left_index)
109109
return index_tree
110110
end
111111

112112
# Count how many constants to the left of this node, and put them in a tree
113-
function index_constants!(tree::Node, index_tree::NodeIndex, left_index::Int16)
113+
function index_constants!(tree::Node, index_tree::NodeIndex, left_index)
114114
if tree.degree == 0
115115
if tree.constant
116116
index_tree.constant_index = left_index + 1

src/EvaluateEquationDerivative.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ function eval_grad_tree_array(
200200
)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Number}
201201
assert_autodiff_enabled(operators)
202202
n_gradients = variable ? size(cX, 1) : count_constants(tree)
203-
index_tree = index_constants(tree, Int16(0))
203+
index_tree = index_constants(tree, UInt16(0))
204204
return eval_grad_tree_array(
205205
tree,
206206
Val(n_gradients),

0 commit comments

Comments
 (0)