Skip to content

Commit a00bed3

Browse files
authored
Merge branch 'main' into fix/macro_structs_generator
2 parents a16fac0 + c96efd5 commit a00bed3

File tree

9 files changed

+414
-75
lines changed

9 files changed

+414
-75
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Elasticsearch"
22
uuid = "e586a49d-aa29-4ce5-8f91-fa4f824367bd"
33
authors = ["Egor Shmorgun <egor.shmorgun@opensesame.com>"]
4-
version = "0.1.5"
4+
version = "0.2.1"
55

66
[deps]
77
CodecZlib = "944b1d66-785c-5afd-91f1-9de20f533193"

src/elastic_transport/ElasticTransport.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ export perform_request
66
include("transport/errors.jl")
77
include("transport/connections/Connections.jl")
88
include("transport/transport.jl")
9+
include("transport/sniffing.jl")
910
include("client.jl")
1011

1112
end

src/elastic_transport/client.jl

Lines changed: 57 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,24 @@ using HTTP
22
using URIs
33
using Mocking
44

5-
const DEFAULT_HOST = "localhost:9200"
6-
7-
struct Client
5+
const DEFAULT_PORT = 9200
6+
const DEFAULT_PROTOCOL = "http"
7+
const DEFAULT_HOST = "localhost"
8+
const DEFAULT_URL = "$DEFAULT_PROTOCOL://$DEFAULT_HOST:$DEFAULT_PORT"
9+
const SECURITY_PRIVILEGES_VALIDATION_WARNING = "The client is unable to verify that the server is Elasticsearch due to security privileges on the server side. Some functionality may not be compatible if the server is running an unsupported product."
10+
const VALIDATION_WARNING = "The client is unable to verify that the server is Elasticsearch. Some functionality may not be compatible if the server is running an unsupported product."
11+
12+
mutable struct Client
813
arguments::Dict
914
options::Dict
1015
hosts::Vector
1116
send_get_body_as::String
12-
ca_fingerpring::Bool
17+
verified::Bool
1318
transport::Transport
1419
end
1520

16-
function Client(arguments::Dict{Symbol,Any}=Dict{Symbol,Any}(); http_client::Module=HTTP)
17-
options = deepcopy(arguments)
21+
function Client(;http_client::Module=HTTP, kwargs...)
22+
options = deepcopy(Dict{Symbol, Any}(kwargs))
1823
arguments = options
1924

2025
get!(options, :reload_connections, false)
@@ -30,7 +35,7 @@ function Client(arguments::Dict{Symbol,Any}=Dict{Symbol,Any}(); http_client::Mod
3035
hosts_config = if !isnothing(host_key_index)
3136
arguments[host_keys[host_key_index]]
3237
else
33-
get(ENV, "ELASTICSEARCH_URL", DEFAULT_HOST)
38+
get(ENV, "ELASTICSEARCH_URL", DEFAULT_URL)
3439
end
3540
hosts = extract_hosts(hosts_config, options)
3641

@@ -50,6 +55,37 @@ function Client(arguments::Dict{Symbol,Any}=Dict{Symbol,Any}(); http_client::Mod
5055
)
5156
end
5257

58+
function verify_elasticsearch(client::Client)
59+
response = nothing
60+
try
61+
response = elastisearch_validation_request(client)
62+
catch exc
63+
if typeof(exc) in [Forbidden, Unauthorized, RequestEntityTooLarge]
64+
client.verified = true
65+
@warn SECURITY_PRIVILEGES_VALIDATION_WARNING
66+
return
67+
else
68+
@warn VALIDATION_WARNING
69+
return
70+
end
71+
end
72+
73+
body = response.body
74+
version = get(() -> Dict(), body, "version") |> version -> get(version, "number", nothing)
75+
76+
verify_with_version_and_headers(client, version, response.headers)
77+
end
78+
79+
@warn "Version verification isn't implemented"
80+
function verify_with_version_and_headers(client::Client, _headers, _version)
81+
@warn "Version verification isn't implemented"
82+
client.verified = true
83+
end
84+
85+
function elastisearch_validation_request(client::Client)
86+
@mock perform_request(client.transport, "GET", "/")
87+
end
88+
5389
function perform_request(
5490
client::Client,
5591
method::String,
@@ -62,23 +98,22 @@ function perform_request(
6298
method = client.send_get_body_as
6399
end
64100

65-
validate_ca_fingerprints(client)
101+
if !client.verified
102+
verify_elasticsearch(client)
103+
end
66104

67105
@mock perform_request(client.transport, method, path; params=params, body=body, headers=headers)
68106
end
69107

70-
@warn "ca fingerprints validation is not implemented"
71-
function validate_ca_fingerprints(::Client)
72-
@warn "ca fingerprints validation is not implemented"
73-
end
74-
75108
function extract_hosts(hosts_config, options)
76109
hosts = if hosts_config isa String
77110
split(hosts_config, ",") .|> strip .|> String
78111
elseif hosts_config isa Vector
79112
hosts_config
80113
elseif hosts_config isa Dict || hosts_config isa URI
81114
[hosts_config]
115+
elseif hosts_config isa NamedTuple
116+
[Dict(zip(keys(hosts_config), values(hosts_config)))]
82117
else
83118
error("Can't extract hosts")
84119
end
@@ -121,8 +156,8 @@ function parse_host_parts(host::String)
121156
host_info = split(host, ":")
122157

123158
Dict(
124-
:host => get(host_info, 0, ""),
125-
:port => get(host_info, 1, "")
159+
:host => get(host_info, 0, DEFAULT_HOST),
160+
:port => parse(Int16, get(host_info, 1, string(DEFAULT_PORT)))
126161
)
127162
end
128163

@@ -131,17 +166,22 @@ end
131166

132167
function parse_host_parts(host::URI)
133168
userinfo = split(host.userinfo, ":") .|> string
169+
port = if isempty(host.port)
170+
DEFAULT_PORT
171+
else
172+
parse(Int16, host.port)
173+
end
134174

135175
Dict(
136176
:scheme => host.scheme,
137177
:user => get(userinfo, 0, ""),
138178
:password => get(userinfo, 1, ""),
139179
:host => host.host,
140180
:path => host.path,
141-
:port => host.port
181+
:port => port
142182
)
143183
end
144184

145-
function parse_host_parts(host::Dict{Symbol,Any})
185+
function parse_host_parts(host::Union{Dict{Symbol,Any}, NamedTuple})
146186
host
147187
end

src/elastic_transport/transport/connections/collection.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,20 @@ function get_connection(collection::Collection)
2424
conn
2525
end
2626

27-
function dead(collection::Collection)
28-
filter(conn -> conn.dead, collection.connections)
27+
function remove!(collection::Collection, conn::Connection)
28+
index = findfirst(==(conn), collection.connections)
29+
30+
if !isnothing(index)
31+
deleteat!(collection.connections, index)
32+
end
2933
end
34+
35+
dead(collection::Collection) = filter(conn -> conn.dead, collection)
36+
37+
Base.length(collection::Collection) = length(collection.connections)
38+
Base.push!(collection::Collection, conns::Connection...) = push!(collection.connections, conns...)
39+
Base.filter(func::Function, collection::Collection) =
40+
Collection(filter(func, collection.connections), collection.selector)
41+
Base.any(func::Function, collection::Collection) = any(func, collection.connections)
42+
Base.foreach(func, collection::Collection) = foreach(func, collection.connections)
43+

src/elastic_transport/transport/connections/connection.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,9 @@ function parse_headers(conn::Connection, headers::Union{Nothing,Dict})
134134
copy(conn.headers)
135135
end
136136
end
137+
138+
function Base.:(==)(src::Connection, other::Connection)
139+
src.host[:protocol] == other.host[:protocol] &&
140+
src.host[:host] == other.host[:host] &&
141+
src.host[:port] == other.host[:port]
142+
end
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
using HTTP
2+
3+
const DEFAULT_SNIFFING_TIMEOUT = 1
4+
const SNIFFING_PROTOCOL = "http"
5+
6+
struct SniffingTimetoutError <: Exception end
7+
8+
function sniff_hosts(transport::Transport)
9+
nodes = perform_sniff_request_with_timeout(transport).body
10+
11+
map(collect(nodes["nodes"])) do (id, info)
12+
if haskey(info, SNIFFING_PROTOCOL)
13+
host, port = parse_publish_address(info[SNIFFING_PROTOCOL]["publish_address"])
14+
15+
Dict(
16+
:id => id,
17+
:name => get(info, "name", nothing),
18+
:version => get(info, "version", nothing),
19+
:host => host,
20+
:port => parse(Int16, port),
21+
:roles => get(info, "roles", nothing),
22+
:attributes => get(info, "attributes", nothing)
23+
)
24+
else
25+
missing
26+
end
27+
end |> skipmissing |> collect
28+
end
29+
30+
function parse_publish_address(publish_address::String)
31+
if !isnothing(match(r"^inet\[.*\]$", publish_address))
32+
parse_address_port(publish_address[begin + 6:end - 1])
33+
elseif !isnothing(match(r"/", publish_address))
34+
parts = split(publish_address, "/") .|> String
35+
36+
[parts[begin], parse_address_port(parts[end])[end]]
37+
else
38+
parse_address_port(publish_address)
39+
end
40+
41+
end
42+
43+
function parse_address_port(publish_address::String)
44+
# If publish address is ipv6
45+
if !isnothing(match(r"[\[\]]", publish_address))
46+
parts = match(r"\A\[(.+)\](?::(\d+))?\z", publish_address)
47+
48+
[parts[1], parts[2]]
49+
else
50+
split(publish_address, ":")
51+
end
52+
end
53+
54+
function perform_sniff_request_with_timeout(transport::Transport)
55+
task = @task(
56+
perform_request(
57+
transport,
58+
"GET",
59+
"/_nodes/$SNIFFING_PROTOCOL",
60+
opts = Dict(:reload_on_failure => false)
61+
)
62+
)
63+
schedule(task)
64+
Timer(DEFAULT_SNIFFING_TIMEOUT) do _timer
65+
istaskdone(task) || Base.throwto(task, SniffingTimetoutError())
66+
end
67+
68+
try
69+
fetch(task)
70+
catch
71+
throw(task.exception)
72+
end
73+
end
74+
75+
function sniffing_timeout(transport::Transport)
76+
get(transport.options, :sniffing_timeout, DEFAULT_SNIFFING_TIMEOUT)
77+
end

src/elastic_transport/transport/transport.jl

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@ using Retry
55
using Mocking
66
using JSON
77

8-
const DEFAULT_PORT = 9200
9-
const DEFAULT_PROTOCOL = "http"
108
const DEFAULT_RELOAD_AFTER = 10_000 # Requests
119
const DEFAULT_RESURRECT_AFTER = 60 # Seconds
1210
const DEFAULT_MAX_RETRIES = 3 # Requests
@@ -62,7 +60,7 @@ function get_connection(transport::Transport, options=Dict())
6260
end
6361

6462
if transport.reload_connections && (transport.counter % transport.reload_after) == 0
65-
reload_connections!
63+
reload_connections!(transport)
6664
end
6765

6866
Connections.get_connection(transport.connections)
@@ -72,9 +70,18 @@ function resurrect_dead_connections!(transport::Transport)
7270
foreach(Connections.resurrect!, Connections.dead(transport.connections))
7371
end
7472

75-
@warn "Reload connections are not implemented"
73+
7674
function reload_connections!(transport::Transport)
77-
@warn "Reload connections are not implemented"
75+
try
76+
hosts = sniff_hosts(transport)
77+
rebuild_connections!(transport, hosts = hosts)
78+
catch e
79+
if e isa SniffingTimetoutError
80+
@error "[SnifferTimeoutError] Timeout when reloading connections."
81+
else
82+
throw(e)
83+
end
84+
end
7885
end
7986

8087
function build_connections(hosts::Vector, options::Dict)
@@ -84,6 +91,23 @@ function build_connections(hosts::Vector, options::Dict)
8491
)
8592
end
8693

94+
function rebuild_connections!(transport::Transport; hosts)
95+
lock(transport.state_lock) do
96+
transport.hosts = hosts
97+
98+
new_connections = build_connections(hosts, transport.options)
99+
stale_connections = filter(transport.connections.connections) do conn
100+
!any(new_conn -> new_conn == conn, new_connections)
101+
end
102+
new_connections = filter(new_connections) do conn
103+
!any(new_conn -> new_conn == conn, transport.connections.connections)
104+
end
105+
106+
foreach(conn -> Connections.remove!(transport.connections, conn), stale_connections)
107+
push!(transport.connections, new_connections.connections...)
108+
end
109+
end
110+
87111
function connections_from_host(hosts::Vector, options::Dict)
88112
map(hosts) do host
89113
host[:protocol] = get(host, :scheme) do
@@ -169,8 +193,7 @@ function perform_request(
169193
if typeof(exception) in HOST_UNREACHABLE_EXCEPTIONS
170194
@error "[$(typeof(exception))] $(connection.host)"
171195

172-
# Disable dead connections, before reload_connections implementation
173-
# Connections.dead!(connection)
196+
Connections.dead!(connection)
174197
end
175198

176199
@retry if reload_on_failure && tries < length(transport.connections) && in(typeof(exception), HOST_UNREACHABLE_EXCEPTIONS)

0 commit comments

Comments
 (0)