diff --git a/lib/ch/connection.ex b/lib/ch/connection.ex index b53394a..da7201c 100644 --- a/lib/ch/connection.ex +++ b/lib/ch/connection.ex @@ -7,6 +7,8 @@ defmodule Ch.Connection do @user_agent "ch/" <> Mix.Project.config()[:version] + @server_display_name_key :server_display_name + @typep conn :: HTTP.t() @impl true @@ -25,6 +27,12 @@ defmodule Ch.Connection do |> maybe_put_private(:username, opts[:username]) |> maybe_put_private(:password, opts[:password]) |> maybe_put_private(:settings, opts[:settings]) + |> HTTP.put_private(:reconnect_opts, %{ + scheme: scheme, + address: address, + port: port, + mint_opts: mint_opts + }) handshake = Query.build("select 1, version()") params = DBConnection.Query.encode(handshake, _params = [], _opts = []) @@ -364,9 +372,11 @@ defmodule Ch.Connection do | {:error, Error.t(), conn} | {:disconnect, Mint.Types.error(), conn} defp request(conn, method, path, headers, body, opts) do - with {:ok, conn, _ref} <- send_request(conn, method, path, headers, body) do - receive_full_response(conn, timeout(conn, opts)) - end + with_retry_if_stale_connection(conn, fn conn -> + with {:ok, conn, _ref} <- send_request(conn, method, path, headers, body) do + receive_full_response(conn, timeout(conn, opts)) + end + end) end @spec request_chunked(conn, binary, binary, Mint.Types.headers(), Enumerable.t(), Keyword.t()) :: @@ -374,9 +384,11 @@ defmodule Ch.Connection do | {:error, Error.t(), conn} | {:disconnect, Mint.Types.error(), conn} def request_chunked(conn, method, path, headers, stream, opts) do - with {:ok, conn, ref} <- send_request(conn, method, path, headers, :stream), - {:ok, conn} <- stream_body(conn, ref, stream), - do: receive_full_response(conn, timeout(conn, opts)) + with_retry_if_stale_connection(conn, fn conn -> + with {:ok, conn, ref} <- send_request(conn, method, path, headers, :stream), + {:ok, conn} <- stream_body(conn, ref, stream), + do: receive_full_response(conn, timeout(conn, opts)) + end) end @spec stream_body(conn, Mint.Types.request_ref(), Enumerable.t()) :: @@ -405,6 +417,56 @@ defmodule Ch.Connection do end end + defp with_retry_if_stale_connection(conn, fun) do + case fun.(conn) do + {:disconnect, reason, conn} -> + if reconnectable_error?(reason) do + case reconnect(conn) do + {:ok, new_conn} -> + fun.(new_conn) + + {:error, reason} -> + {:disconnect, reason, conn} + end + else + {:disconnect, reason, conn} + end + + other -> + other + end + end + + defp reconnectable_error?(%Mint.TransportError{reason: :closed}), do: true + defp reconnectable_error?(%Mint.TransportError{reason: :econnreset}), do: true + defp reconnectable_error?(_), do: false + + @spec reconnect(conn) :: {:ok, conn} | {:error, Mint.Types.error()} + defp reconnect(conn) do + %{scheme: scheme, address: address, port: port, mint_opts: mint_opts} = + HTTP.get_private(conn, :reconnect_opts) + + {:ok, _closed_conn} = HTTP.close(conn) + + case HTTP.connect(scheme, address, port, mint_opts) do + {:ok, new_conn} -> + new_conn = + new_conn + |> HTTP.put_private(:timeout, HTTP.get_private(conn, :timeout)) + |> maybe_put_private(:database, HTTP.get_private(conn, :database)) + |> maybe_put_private(:username, HTTP.get_private(conn, :username)) + |> maybe_put_private(:password, HTTP.get_private(conn, :password)) + |> maybe_put_private(:settings, HTTP.get_private(conn, :settings)) + |> HTTP.put_private(:reconnect_opts, HTTP.get_private(conn, :reconnect_opts)) + |> maybe_put_private(@server_display_name_key, HTTP.get_private(conn, @server_display_name_key)) + + {:ok, new_conn} + + {:error, _reason} = error -> + error + end + end + @spec receive_full_response(conn, timeout) :: {:ok, conn, [response]} | {:error, Error.t(), conn} @@ -499,8 +561,6 @@ defmodule Ch.Connection do "/?" <> URI.encode_query(settings ++ query_params) end - @server_display_name_key :server_display_name - @spec ensure_same_server(conn, Mint.Types.headers()) :: conn defp ensure_same_server(conn, headers) do expected_name = HTTP.get_private(conn, @server_display_name_key)