Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/proxy/http/HttpSM.h
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,7 @@ class HttpSM : public Continuation, public PluginUserArgs<TS_USER_ARGS_TXN>
int server_connection_provided_cert = 0;
int64_t client_request_body_bytes = 0;
int64_t server_request_body_bytes = 0;
bool server_request_body_incomplete = false;
int64_t server_response_body_bytes = 0;
int64_t client_response_body_bytes = 0;
int64_t cache_response_body_bytes = 0;
Expand Down
29 changes: 27 additions & 2 deletions src/proxy/http/HttpSM.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2070,11 +2070,31 @@ HttpSM::state_read_server_response_header(int event, void *data)

// If there is a post body in transit, give up on it
if (tunnel.is_tunnel_alive()) {
// Record bytes already written to the server before aborting the tunnel.
// tunnel_handler_post_server() won't be called after abort, so we must
// capture this here for stats/logging purposes.
HttpTunnelConsumer *server_consumer = tunnel.get_consumer(server_txn);
if (server_consumer && server_request_body_bytes == 0) {
server_request_body_bytes = server_consumer->bytes_written;
}
// Mark the body transfer as incomplete so the origin connection is not
// pooled. The origin may have unconsumed body data in the TCP stream.
server_request_body_incomplete = true;
tunnel.abort_tunnel();
// Make sure client connection is closed when we are done in case there is cruft left over
t_state.client_info.keep_alive = HTTPKeepAlive::NO_KEEPALIVE;
// Similarly the server connection should also be closed
t_state.current.server->keep_alive = HTTPKeepAlive::NO_KEEPALIVE;
} else if (!server_request_body_incomplete && server_request_body_bytes > 0 &&
t_state.hdr_info.client_request.m_100_continue_sent) {
// When ATS proactively sent 100 Continue to the client
// (send_100_continue_response), the body tunnel was set up before the
// origin confirmed it would accept the body. The tunnel may have
// completed before the origin responded, but the origin might not have
// consumed the body data. Prevent connection pooling to avoid the next
// request on this connection seeing leftover body bytes as corruption.
server_request_body_incomplete = true;
t_state.current.server->keep_alive = HTTPKeepAlive::NO_KEEPALIVE;
}
}

Expand Down Expand Up @@ -3176,8 +3196,13 @@ HttpSM::tunnel_handler_server(int event, HttpTunnelProducer *p)

bool close_connection = false;

// Don't pool the connection if the request body transfer was incomplete.
// The origin may not have consumed all of it before sending this response,
// leaving unconsumed body data in the TCP stream that would corrupt the
// next request on this connection.
if (t_state.current.server->keep_alive == HTTPKeepAlive::KEEPALIVE && server_entry->eos == false &&
plugin_tunnel_type == HttpPluginTunnel_t::NONE && t_state.txn_conf->keep_alive_enabled_out == 1) {
plugin_tunnel_type == HttpPluginTunnel_t::NONE && t_state.txn_conf->keep_alive_enabled_out == 1 &&
!server_request_body_incomplete) {
close_connection = false;
} else {
if (t_state.current.server->keep_alive != HTTPKeepAlive::KEEPALIVE) {
Expand Down Expand Up @@ -6035,7 +6060,7 @@ HttpSM::release_server_session(bool serve_from_cache)
(t_state.hdr_info.server_response.status_get() == HTTPStatus::NOT_MODIFIED ||
(t_state.hdr_info.server_request.method_get_wksidx() == HTTP_WKSIDX_HEAD &&
t_state.www_auth_content != HttpTransact::CacheAuth_t::NONE)) &&
plugin_tunnel_type == HttpPluginTunnel_t::NONE && (!server_entry || !server_entry->eos)) {
plugin_tunnel_type == HttpPluginTunnel_t::NONE && (!server_entry || !server_entry->eos) && !server_request_body_incomplete) {
if (t_state.www_auth_content == HttpTransact::CacheAuth_t::NONE || serve_from_cache == false) {
// Must explicitly set the keep_alive_no_activity time before doing the release
server_txn->set_inactivity_timeout(HRTIME_SECONDS(t_state.txn_conf->keep_alive_no_activity_timeout_out));
Expand Down
100 changes: 100 additions & 0 deletions tests/gold_tests/post/corruption_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#!/usr/bin/env python3
"""Client that sends two requests on one TCP connection to reproduce
100-continue connection pool corruption."""

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from http_utils import wait_for_headers_complete, determine_outstanding_bytes_to_read, drain_socket

import argparse
import socket
import sys
import time


def main() -> int:
    """Send two requests over one proxy connection and report whether the
    second request observed corrupted data.

    Request 1 is a POST with ``Expect: 100-continue`` whose body is sent
    without waiting for the interim response; request 2 is a plain GET on
    the same client connection. The verdict is printed to stdout for the
    autest harness to match against.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('proxy_address')
    arg_parser.add_argument('proxy_port', type=int)
    arg_parser.add_argument('-s', '--server-hostname', dest='server_hostname', default='example.com')
    options = arg_parser.parse_args()

    origin_host = options.server_hostname
    payload = b'X' * 103

    conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    conn.connect((options.proxy_address, options.proxy_port))

    with conn:
        # --- Request 1: POST advertising Expect: 100-continue. ---
        post_headers = (
            f'POST /expect-100-corrupted HTTP/1.1\r\n'
            f'Host: {origin_host}\r\n'
            f'Connection: keep-alive\r\n'
            f'Content-Length: {len(payload)}\r\n'
            f'Expect: 100-continue\r\n'
            f'\r\n').encode()
        conn.sendall(post_headers)

        # Ship the body after a brief pause without waiting for 100-continue.
        time.sleep(0.5)
        conn.sendall(payload)

        # Read the first response headers (may be 100 + final, or just final).
        first_response = wait_for_headers_complete(conn)

        # Skip past an interim 100 Continue to reach the real response.
        if b'100' in first_response.split(b'\r\n')[0]:
            remainder = first_response.split(b'\r\n\r\n', 1)[1] if b'\r\n\r\n' in first_response else b''
            if b'\r\n\r\n' not in remainder:
                remainder += wait_for_headers_complete(conn)
            first_response = remainder

        # Consume the body of the first response, if any.
        try:
            remaining = determine_outstanding_bytes_to_read(first_response)
            if remaining > 0:
                drain_socket(conn, first_response, remaining)
        except ValueError:
            pass

        # Give ATS a moment to return the origin connection to its pool.
        time.sleep(0.5)

        # --- Request 2: plain GET reusing the same client connection. ---
        get_request = (f'GET /second-request HTTP/1.1\r\n'
                       f'Host: {origin_host}\r\n'
                       f'Connection: close\r\n'
                       f'\r\n').encode()
        conn.sendall(get_request)

        second_response = wait_for_headers_complete(conn)
        second_status = second_response.split(b'\r\n')[0]

        if b'400' in second_status or b'corrupted' in second_response.lower():
            print('Corruption detected: second request saw corrupted data', flush=True)
        elif b'502' in second_status:
            print('Corruption detected: ATS returned 502 (origin parse error)', flush=True)
        else:
            print('No corruption: second request completed normally', flush=True)

    return 0


if __name__ == '__main__':
    # Propagate main()'s status code as the process exit code.
    raise SystemExit(main())
155 changes: 155 additions & 0 deletions tests/gold_tests/post/corruption_origin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
#!/usr/bin/env python3
"""Origin that sends a 301 without consuming the request body, then checks
whether a reused connection carries leftover (corrupted) data. Handles
multiple connections so that a fixed ATS can open a fresh one for the
second request."""

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import socket
import sys
import threading
import time

# Request methods this origin accepts as the start of a legitimate request.
VALID_METHODS = {'GET', 'POST', 'PUT', 'DELETE', 'HEAD', 'OPTIONS', 'PATCH'}


def read_until_headers_complete(conn: socket.socket) -> bytes:
    """Accumulate bytes from conn until a complete header block (blank-line
    terminated) has arrived, or the peer closes the connection.

    Returns whatever was read, which may be empty or lack the terminator if
    the peer disconnected early.
    """
    received = bytearray()
    while b'\r\n\r\n' not in received:
        piece = conn.recv(4096)
        if not piece:
            # Peer closed; hand back the partial (possibly empty) read.
            return bytes(received)
        received.extend(piece)
    return bytes(received)


def is_valid_http_request_line(line: str) -> bool:
    """Return True when ``line`` looks like a well-formed HTTP request line:
    a known method, at least three space-separated tokens, and an HTTP/
    version as the last token."""
    tokens = line.strip().split(' ')
    return len(tokens) >= 3 and tokens[0] in VALID_METHODS and tokens[-1].startswith('HTTP/')


def send_200(conn: socket.socket) -> None:
    """Write a minimal 200 OK response with a two-byte body to conn."""
    payload = b'OK'
    pieces = [
        b'HTTP/1.1 200 OK',
        b'Content-Length: ' + str(len(payload)).encode(),
        b'',
        payload,
    ]
    conn.sendall(b'\r\n'.join(pieces))


def handle_connection(conn: socket.socket, args: argparse.Namespace, result: dict) -> None:
    """Serve one connection from the proxy.

    For a POST, reply 301 after ``args.delay`` seconds WITHOUT reading the
    request body, then watch the same connection for reuse: if more bytes
    arrive and they do not form a valid request line, they are leftover
    body data — corruption from pooling an unclean connection. For a GET,
    treat it as the second request arriving on a fresh connection (the
    fixed behavior) and answer 200.

    Outcomes are recorded in the shared ``result`` dict under the keys
    'corrupted' and 'new_connection'.
    """
    try:
        data = read_until_headers_complete(conn)
        if not data:
            # A connection carrying no data is the harness's readiness probe.
            conn.close()
            return

        # The request line identifies which phase of the test this is.
        first_line = data.split(b'\r\n')[0].decode('utf-8', errors='replace')

        if first_line.startswith('POST'):
            # First request: send 301 without consuming the body. The delay
            # gives the proxy time to tunnel the body toward us first.
            time.sleep(args.delay)

            body = b'Redirecting'
            response = (
                b'HTTP/1.1 301 Moved Permanently\r\n'
                b'Location: http://example.com/\r\n'
                b'Connection: keep-alive\r\n'
                b'Content-Length: ' + str(len(body)).encode() + b'\r\n'
                b'\r\n' + body)
            conn.sendall(response)

            # Wait for potential reuse on this connection. A timeout here
            # simply means the proxy never reused it (which is fine).
            conn.settimeout(args.timeout)
            try:
                second_data = b''
                # Only need the first line to judge validity, so stop at
                # the first CRLF rather than a full header block.
                while b'\r\n' not in second_data:
                    chunk = conn.recv(4096)
                    if not chunk:
                        break
                    second_data += chunk

                if second_data:
                    second_line = second_data.split(b'\r\n')[0].decode('utf-8', errors='replace')
                    if is_valid_http_request_line(second_line):
                        # Clean reuse: a proper request arrived; serve it.
                        send_200(conn)
                    else:
                        # Leftover unread body bytes arrived first: corruption.
                        result['corrupted'] = True
                        err_body = b'corrupted'
                        conn.sendall(
                            b'HTTP/1.1 400 Bad Request\r\n'
                            b'Content-Length: ' + str(len(err_body)).encode() + b'\r\n'
                            b'\r\n' + err_body)
            except socket.timeout:
                pass

        elif first_line.startswith('GET'):
            # Second request on a new connection (fix is working).
            result['new_connection'] = True
            send_200(conn)

        conn.close()
    except Exception:
        # Best-effort teardown; this helper must never take the server down.
        try:
            conn.close()
        except Exception:
            pass


def main() -> int:
    """Listen on the requested port and hand each accepted connection to a
    daemon worker thread, stopping after ten connections or an accept
    timeout."""
    cli = argparse.ArgumentParser()
    cli.add_argument('port', type=int)
    cli.add_argument('--delay', type=float, default=1.0)
    cli.add_argument('--timeout', type=float, default=5.0)
    opts = cli.parse_args()

    listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    listener.bind(('', opts.port))
    listener.listen(5)
    # Stop accepting once the test has clearly gone quiet.
    listener.settimeout(opts.timeout + 5)

    outcome = {'corrupted': False, 'new_connection': False}
    workers = []

    try:
        for _ in range(10):
            try:
                client, _addr = listener.accept()
            except socket.timeout:
                break
            worker = threading.Thread(target=handle_connection, args=(client, opts, outcome))
            worker.daemon = True
            worker.start()
            workers.append(worker)
    except Exception:
        pass

    # Bound the wait so a stuck handler cannot hang the process.
    for worker in workers:
        worker.join(timeout=opts.timeout + 2)

    listener.close()
    return 0


if __name__ == '__main__':
    # Propagate main()'s status code as the process exit code.
    raise SystemExit(main())
67 changes: 67 additions & 0 deletions tests/gold_tests/post/expect-100-continue-corruption.test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys

Test.Summary = '''
Verify that when an origin responds before consuming the request body on a
connection with Expect: 100-continue, ATS does not return the origin connection
to the pool with unconsumed data.
'''

tr = Test.AddTestRun('Verify 100-continue with early origin response does not corrupt pooled connections.')

# DNS: resolve every hostname (including backend.example.com) to loopback.
dns = tr.MakeDNServer('dns', default='127.0.0.1')

# Origin: replies 301 to the POST without reading the request body, then
# watches for leftover bytes if the connection is reused.
Test.GetTcpPort('origin_port')
tr.Setup.CopyAs('corruption_origin.py')
origin = tr.Processes.Process(
    'origin', f'{sys.executable} corruption_origin.py '
    f'{Test.Variables.origin_port} --delay 1.0 --timeout 5.0')
origin.Ready = When.PortOpen(Test.Variables.origin_port)

# ATS: cache disabled so the POST/GET always reach the origin.
ts = tr.MakeATSProcess('ts', enable_cache=False)
ts.Disk.remap_config.AddLine(f'map / http://backend.example.com:{Test.Variables.origin_port}')
ts.Disk.records_config.update(
    {
        'proxy.config.diags.debug.enabled': 1,
        'proxy.config.diags.debug.tags': 'http',
        'proxy.config.dns.nameservers': f'127.0.0.1:{dns.Variables.Port}',
        'proxy.config.dns.resolv_conf': 'NULL',
        # Have ATS answer 100 Continue itself, setting up the body tunnel
        # before the origin responds — the scenario under test.
        'proxy.config.http.send_100_continue_response': 1,
    })

# Client: sends the Expect: 100-continue POST followed by a GET on the same
# client connection, and prints the corruption verdict to stdout.
tr.Setup.CopyAs('corruption_client.py')
tr.Setup.CopyAs('http_utils.py')
tr.Processes.Default.Command = (
    f'{sys.executable} corruption_client.py '
    f'127.0.0.1 {ts.Variables.port} '
    f'-s backend.example.com')
tr.Processes.Default.ReturnCode = 0
tr.Processes.Default.StartBefore(dns)
tr.Processes.Default.StartBefore(origin)
tr.Processes.Default.StartBefore(ts)

# With the fix, ATS should not pool the origin connection when the
# request body was not fully consumed, preventing corruption.
tr.Processes.Default.Streams.stdout += Testers.ContainsExpression(
    'No corruption', 'The second request should complete normally because ATS '
    'does not pool origin connections with unconsumed body data.')
tr.Processes.Default.Streams.stdout += Testers.ExcludesExpression('Corruption detected', 'No corruption should be detected.')