Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
242 changes: 224 additions & 18 deletions grpc_client/lib/grpc/client/connection.ex
Original file line number Diff line number Diff line change
Expand Up @@ -80,22 +80,31 @@ defmodule GRPC.Client.Connection do
@insecure_scheme "http"
@secure_scheme "https"
@refresh_interval 15_000
@default_resolve_interval 30_000
@default_max_resolve_interval 300_000
@default_min_resolve_interval 5_000

@type t :: %__MODULE__{
virtual_channel: Channel.t(),
real_channels: %{String.t() => {:ok, Channel.t()} | {:error, any()}},
real_channels: %{String.t() => {:connected, Channel.t()} | {:failed, any()}},
lb_mod: module() | nil,
lb_state: term() | nil,
resolver: module() | nil,
adapter: module()
adapter: module(),
resolver_target: String.t() | nil,
connect_opts: keyword(),
dns_resolver_pid: pid() | nil
}

defstruct virtual_channel: nil,
real_channels: %{},
lb_mod: nil,
lb_state: nil,
resolver: nil,
adapter: GRPC.Client.Adapters.Gun
adapter: GRPC.Client.Adapters.Gun,
resolver_target: nil,
connect_opts: [],
dns_resolver_pid: nil

def child_spec(initial_state) do
%{
Expand All @@ -121,6 +130,26 @@ defmodule GRPC.Client.Connection do
)

Process.send_after(self(), :refresh, @refresh_interval)

# Only start periodic re-resolution for DNS targets — static targets
# (ipv4:, ipv6:, unix:) always resolve to the same addresses.
state =
if state.resolver && state.resolver_target && dns_target?(state.resolver_target) do
{:ok, pid} =
GRPC.Client.DNSResolver.start_link(
connection_pid: self(),
resolver: state.resolver,
target: state.resolver_target,
resolve_interval: state.connect_opts[:resolve_interval],
max_resolve_interval: state.connect_opts[:max_resolve_interval],
min_resolve_interval: state.connect_opts[:min_resolve_interval]
)

%{state | dns_resolver_pid: pid}
else
state
end

{:ok, state}
end

Expand All @@ -140,6 +169,9 @@ defmodule GRPC.Client.Connection do
* `:codec` – request/response codec (default: `GRPC.Codec.Proto`)
* `:compressor` / `:accepted_compressors` – message compression
* `:headers` – default metadata headers
* `:resolve_interval` – DNS re-resolution interval in ms (default: 30000)
* `:max_resolve_interval` – backoff cap in ms (default: 300000)
* `:min_resolve_interval` – rate-limit floor in ms (default: 5000)

Returns:

Expand Down Expand Up @@ -227,14 +259,33 @@ defmodule GRPC.Client.Connection do
end
end

@doc """
Triggers an immediate DNS re-resolution, subject to rate limiting.

Intended for use by health checks or heartbeat mechanisms that detect
a backend has gone away and want to force a fresh DNS lookup.
"""
@spec resolve_now(Channel.t()) :: :ok
def resolve_now(%Channel{ref: ref}) do
GenServer.cast(via(ref), :resolve_now)
end

@impl GenServer
def handle_cast(:resolve_now, %{dns_resolver_pid: pid} = state) when is_pid(pid) do
send(pid, :resolve_now)
{:noreply, state}
end

def handle_cast(:resolve_now, state), do: {:noreply, state}

@impl GenServer
def handle_call({:disconnect, %Channel{adapter: adapter} = channel}, _from, state) do
resp = {:ok, %Channel{channel | adapter_payload: %{conn_pid: nil}}}
:persistent_term.erase({__MODULE__, :lb_state, channel.ref})

if Map.has_key?(state, :real_channels) do
Enum.map(state.real_channels, fn
{_key, {:ok, ch}} ->
{_key, {:connected, ch}} ->
do_disconnect(adapter, ch)

_ ->
Expand Down Expand Up @@ -262,23 +313,31 @@ defmodule GRPC.Client.Connection do
channel_key = build_address_key(prefer_host, prefer_port)

case Map.get(channels, channel_key) do
nil ->
Logger.warning("LB picked #{channel_key}, but no channel found in pool")
{:connected, %Channel{} = picked_channel} ->
:persistent_term.put({__MODULE__, :lb_state, vc.ref}, picked_channel)

Process.send_after(self(), :refresh, @refresh_interval)
{:noreply, %{state | lb_state: new_lb_state}}
{:noreply, %{state | lb_state: new_lb_state, virtual_channel: picked_channel}}

{:ok, %Channel{} = picked_channel} ->
:persistent_term.put({__MODULE__, :lb_state, vc.ref}, picked_channel)
_nil_or_failed ->
# LB picked a channel that is missing or in {:failed, _} state.
# Don't update persistent_term — keep serving from the current
# virtual_channel until re-resolution provides healthy backends.
Logger.warning("LB picked #{channel_key}, but channel is unavailable")

Process.send_after(self(), :refresh, @refresh_interval)

{:noreply, %{state | lb_state: new_lb_state, virtual_channel: picked_channel}}
{:noreply, %{state | lb_state: new_lb_state}}
end
end

def handle_info(:refresh, state), do: {:noreply, state}

# Result from the dedicated DNSResolver process
def handle_info({:dns_result, result}, state) do
state = handle_resolve_result(result, state)
{:noreply, state}
end

def handle_info({:DOWN, _ref, :process, pid, reason}, state) do
Logger.warning(
"#{inspect(__MODULE__)} received :DOWN from #{inspect(pid)} with reason: #{inspect(reason)}"
Expand Down Expand Up @@ -308,6 +367,146 @@ defmodule GRPC.Client.Connection do

def terminate(_reason, _state), do: :ok

defp handle_resolve_result({:ok, %{addresses: []}}, state), do: state

defp handle_resolve_result({:ok, %{addresses: new_addresses}}, state) do
reconcile_channels(new_addresses, state.adapter, state.connect_opts, state)
end

defp handle_resolve_result({:error, _reason}, state), do: state

defp reconcile_channels(new_addresses, adapter, opts, state) do
new_keys = MapSet.new(new_addresses, &build_address_key(&1.address, &1.port))
old_keys = MapSet.new(Map.keys(state.real_channels))

added = MapSet.difference(new_keys, old_keys)
removed = MapSet.difference(old_keys, new_keys)

real_channels = disconnect_removed_channels(removed, adapter, state.real_channels)
real_channels = connect_new_channels(new_addresses, added, adapter, opts, state, real_channels)
rebalance_after_reconcile(new_addresses, real_channels, state)
end

defp disconnect_removed_channels(removed, adapter, real_channels) do
Enum.reduce(MapSet.to_list(removed), real_channels, fn key, channels ->
case Map.get(channels, key) do
{:connected, ch} -> do_disconnect(adapter, ch)
_ -> :ok
end

Map.delete(channels, key)
end)
end

defp connect_new_channels(new_addresses, added, adapter, opts, state, real_channels) do
Enum.reduce(new_addresses, real_channels, fn %{address: host, port: port}, channels ->
key = build_address_key(host, port)
existing = Map.get(channels, key)

should_connect =
MapSet.member?(added, key) or
match?({:failed, _}, existing) or
not channel_alive?(existing)

if should_connect do
case existing do
{:connected, ch} -> do_disconnect(adapter, ch)
_ -> :ok
end

case connect_real_channel(state.virtual_channel, host, port, opts, adapter) do
{:ok, ch} -> Map.put(channels, key, {:connected, ch})
{:error, reason} -> Map.put(channels, key, {:failed, reason})
end
else
channels
end
end)
end

# Re-init load balancer with full updated address list.
#
# NOTE: We guard persistent_term writes to only happen when the picked
# channel actually changes. persistent_term updates trigger a global GC
# pass across all BEAM processes (see erlang.org/doc/apps/erts/persistent_term).
# With periodic re-resolution this function runs every 30s+ per connection,
# and on no-change cycles we must avoid redundant writes. A future
# improvement would be migrating to ETS with read_concurrency: true,
# which has no global GC cost on writes.
defp rebalance_after_reconcile(new_addresses, real_channels, state) do
if state.lb_mod do
case state.lb_mod.init(addresses: new_addresses) do
{:ok, new_lb_state} ->
{:ok, {host, port}, picked_lb_state} = state.lb_mod.pick(new_lb_state)
key = build_address_key(host, port)

case Map.get(real_channels, key) do
{:connected, picked_channel} ->
maybe_update_persistent_term(state.virtual_channel, picked_channel)

%{
state
| real_channels: real_channels,
lb_state: picked_lb_state,
virtual_channel: picked_channel
}

_ ->
fallback_to_healthy_channel(state, real_channels, picked_lb_state)
end

{:error, _} ->
fallback_to_healthy_channel(state, real_channels, state.lb_state)
end
else
fallback_to_healthy_channel(state, real_channels, state.lb_state)
end
end

defp fallback_to_healthy_channel(state, real_channels, lb_state) do
ref = state.virtual_channel.ref

case Enum.find_value(real_channels, fn {_k, v} -> match?({:connected, _}, v) && v end) do
{:connected, healthy_channel} ->
maybe_update_persistent_term(state.virtual_channel, healthy_channel)

%{
state
| real_channels: real_channels,
lb_state: lb_state,
virtual_channel: healthy_channel
}

nil ->
Logger.warning("No healthy channels available after re-resolution")
:persistent_term.erase({__MODULE__, :lb_state, ref})
%{state | real_channels: real_channels, lb_state: lb_state}
end
end

# Only write to persistent_term when the channel actually changed.
# persistent_term updates trigger a global GC pass, so we skip
# redundant writes on no-change re-resolution cycles.
defp maybe_update_persistent_term(current_channel, new_channel) do
if current_channel != new_channel do
:persistent_term.put(
{__MODULE__, :lb_state, new_channel.ref},
new_channel
)
end
end

defp channel_alive?({:connected, %{adapter_payload: %{conn_pid: pid}}}) when is_pid(pid) do
Process.alive?(pid)
end

defp channel_alive?({:connected, _}), do: true
defp channel_alive?(_), do: false

defp dns_target?(target) do
URI.parse(target).scheme == "dns"
end

defp via(ref) do
{:global, {__MODULE__, ref}}
end
Expand All @@ -333,7 +532,12 @@ defmodule GRPC.Client.Connection do
codec: GRPC.Codec.Proto,
compressor: nil,
accepted_compressors: [],
headers: []
headers: [],
lb_policy: nil,
resolver: GRPC.Client.Resolver,
resolve_interval: @default_resolve_interval,
max_resolve_interval: @default_max_resolve_interval,
min_resolve_interval: @default_min_resolve_interval
)

resolver = Keyword.get(opts, :resolver, GRPC.Client.Resolver)
Expand Down Expand Up @@ -364,7 +568,9 @@ defmodule GRPC.Client.Connection do
base_state = %__MODULE__{
virtual_channel: virtual_channel,
resolver: resolver,
adapter: adapter
adapter: adapter,
resolver_target: norm_target,
connect_opts: norm_opts
}

case resolver.resolve(norm_target) do
Expand Down Expand Up @@ -423,7 +629,7 @@ defmodule GRPC.Client.Connection do

key = build_address_key(prefer_host, prefer_port)

with {:ok, ch} <- Map.get(real_channels, key, {:error, :no_channel}) do
with {:connected, ch} <- Map.get(real_channels, key, {:failed, :no_channel}) do
{:ok,
%__MODULE__{
base_state
Expand All @@ -433,7 +639,7 @@ defmodule GRPC.Client.Connection do
real_channels: real_channels
}}
else
{:error, reason} -> {:error, reason}
{:failed, reason} -> {:error, reason}
end

{:error, :no_addresses} ->
Expand All @@ -451,7 +657,7 @@ defmodule GRPC.Client.Connection do
%__MODULE__{
base_state
| virtual_channel: ch,
real_channels: %{"#{host}:#{port}" => {:ok, ch}}
real_channels: %{"#{host}:#{port}" => {:connected, ch}}
}}

{:error, reason} ->
Expand All @@ -469,10 +675,10 @@ defmodule GRPC.Client.Connection do
adapter
) do
{:ok, ch} ->
{build_address_key(host, port), {:ok, ch}}
{build_address_key(host, port), {:connected, ch}}

{:error, reason} ->
{build_address_key(host, port), {:error, reason}}
{build_address_key(host, port), {:failed, reason}}
end
end)
end
Expand Down
Loading