diff --git a/awscli/customizations/ec2instanceconnect/websocket.py b/awscli/customizations/ec2instanceconnect/websocket.py index 68b1828c0262..0aaf4991225b 100644 --- a/awscli/customizations/ec2instanceconnect/websocket.py +++ b/awscli/customizations/ec2instanceconnect/websocket.py @@ -76,7 +76,16 @@ def has_data_to_read(self): return False def read(self, amt) -> bytes: - return sys.stdin.buffer.read1(amt) + try: + data = sys.stdin.buffer.read1(amt) + # Empty data indicates EOF (pipe closed) + if not data: + logger.debug("Stdin returned empty data (EOF). Input is closed.") + raise InputClosedError() + return data + except (OSError, IOError) as e: + logger.debug(f"IO error reading from stdin: {str(e)}") + raise InputClosedError() def write(self, data): sys.stdout.buffer.write(data) @@ -88,38 +97,70 @@ def close(self): class WindowsStdinStdoutIO(StdinStdoutIO): def has_data_to_read(self): - return True + # For Windows, we can't reliably check stdin without blocking + # We'll rely on the read method to detect when input is closed + # by catching EOF errors in the calling code + try: + if sys.stdin.closed: + return False + return True + except (OSError, ValueError, IOError): + return False class TCPSocketIO(BaseWebsocketIO): def __init__(self, conn): self.conn = conn + self._is_closed = False def has_data_to_read(self): - return True + if self._is_closed: + return False + + # Use select with a timeout to check if there's data + try: + read_ready, _, _ = select.select([self.conn], [], [], _SELECT_TIMEOUT) + return bool(read_ready) + except (OSError, ValueError, socket.error): + self._is_closed = True + return False def read(self, amt) -> bytes: - data = self.conn.recv(amt) - # In listener mode use can CTRL+C during host verification that kills the client TCP connect, - # when this happens we are able to successfully disconnect because has_data_to_read always return true. - # This will check if data is empty and if yes then raise InputCloseError - # - # recv() relies on the underlying system call which returns empty bytes when the connection is closed. - # Linux: https://manpages.debian.org/bullseye/manpages-dev/recv.2.en.html - # Windows: https://learn.microsoft.com/en-us/windows/win32/api/winsock/nf-winsock-recv - if not data: + try: + data = self.conn.recv(amt) + # In listener mode use can CTRL+C during host verification that kills the client TCP connect, + # when this happens we are able to successfully disconnect because has_data_to_read always return true. + # This will check if data is empty and if yes then raise InputCloseError + # + # recv() relies on the underlying system call which returns empty bytes when the connection is closed. + # Linux: https://manpages.debian.org/bullseye/manpages-dev/recv.2.en.html + # Windows: https://learn.microsoft.com/en-us/windows/win32/api/winsock/nf-winsock-recv + if not data: + self._is_closed = True + raise InputClosedError() + return data + except (OSError, socket.error): + self._is_closed = True raise InputClosedError() - return data def write(self, data): - self.conn.sendall(data) + if self._is_closed: + raise InputClosedError() + try: + self.conn.sendall(data) + except (OSError, socket.error): + self._is_closed = True + raise InputClosedError() def close(self): - try: - self.conn.close() - # On Windows, we could receive an OSError if the tcp conn is already closed. - except OSError: - pass + if not self._is_closed: + self._is_closed = True + try: + self.conn.shutdown(socket.SHUT_RDWR) + self.conn.close() + # On Windows, we could receive an OSError if the tcp conn is already closed. + except OSError: + pass class Websocket: @@ -217,9 +258,25 @@ def write_data_from_input(self): try: # Start writing data to the websocket connection and block current thread. self._write_data_from_input() + except Exception as e: + logger.error(f"Unexpected error in write_data_from_input: {str(e)}") finally: + # Make sure to clean up on exit + logger.debug("Exiting write_data_from_input, cleaning up") self.close() + # If we're a stdin/stdout websocket and input was closed, + # ensure the process exits cleanly + if isinstance(self.websocketio, StdinStdoutIO) or isinstance(self.websocketio, WindowsStdinStdoutIO): + logger.debug("Stdin/stdout websocket closed, exiting process") + # This is a bit drastic but necessary to ensure the process exits + # when stdin is closed in pipe mode + import os + import signal + # Send SIGTERM to ourselves to initiate clean shutdown + # This is more reliable than sys.exit() which can be caught + os.kill(os.getpid(), signal.SIGTERM) + if self._exception: raise self._exception @@ -231,16 +288,36 @@ def close(self): def _write_data_from_input(self): while not self._shutdown_event.is_set(): + # Check if websocket is still valid + if not self._websocket: + logger.debug('Websocket is closed or invalid. Exiting write loop.') + self.close() + return + # Wait until there's some data to read - if not self.websocketio.has_data_to_read(): - time.sleep(self._WAIT_INTERVAL_FOR_INPUT) - continue + try: + if not self.websocketio.has_data_to_read(): + time.sleep(self._WAIT_INTERVAL_FOR_INPUT) + continue + except Exception as e: + logger.debug(f'Error checking for data: {str(e)}. Shutting down websocket.') + self.close() + return try: data = self.websocketio.read(self._MAX_BYTES_PER_FRAME) + # Skip empty data (shouldn't happen, but as a safeguard) + if not data: + logger.debug('Received empty data. Skipping frame.') + continue except InputClosedError: logger.debug('Input closed. Shutting down websocket.') self.close() + return + except Exception as e: + logger.debug(f'Error reading data: {str(e)}. Shutting down websocket.') + self.close() + return try: self._websocket.send_frame( @@ -248,8 +325,15 @@ def _write_data_from_input(self): payload=data, on_complete=self._on_send_frame_complete_data, ) - # Block until send_frame on_complete - self._send_frame_results_queue.get() + # Block until send_frame on_complete with a timeout + try: + result = self._send_frame_results_queue.get(timeout=5.0) + if result and hasattr(result, 'exception') and result.exception: + raise result.exception + except Exception as e: + logger.debug(f'Timeout or error waiting for frame completion: {str(e)}') + self.close() + return except RuntimeError as e: crt_exceptions = [ "AWS_ERROR_HTTP_WEBSOCKET_CLOSE_FRAME_SENT", @@ -261,8 +345,15 @@ def _write_data_from_input(self): f"Received exception when sending websocket frame: {e.args}" ) self.close() + return else: + logger.debug(f"Unhandled runtime error: {e.args}") + self.close() raise e + except Exception as e: + logger.debug(f'Unexpected error sending frame: {str(e)}') + self.close() + return def _on_connection(self, data: OnConnectionSetupData) -> None: request_id_header = [ @@ -354,9 +445,21 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): - for _, web_socket in self._inflight_futures_and_websockets: - # Close the websocket handlers. - web_socket.close() + logger.debug("Shutting down WebsocketManager") + # First set RUNNING flag to false so any remaining loops exit + self.RUNNING.set() + + # Close all websocket handlers + for future, web_socket in self._inflight_futures_and_websockets: + try: + web_socket.close() + # Try to cancel any still-running futures + if not future.done(): + future.cancel() + except Exception as e: + logger.debug(f"Error closing websocket: {str(e)}") + + # Close server socket if exists if self._socket: try: self._socket.shutdown(socket.SHUT_RDWR) @@ -364,7 +467,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): # On Windows, if the socket is already closed, we will get an OSError. except OSError: pass - self._executor.shutdown() + + # Shutdown executor with a timeout + logger.debug("Shutting down executor") + self._executor.shutdown(wait=False) + logger.debug("WebsocketManager shutdown complete") # Used to break out of while loop in tests. RUNNING = threading.Event() @@ -375,11 +482,20 @@ def run(self): websocketio = ( WindowsStdinStdoutIO() if is_windows else StdinStdoutIO() ) - future = self._open_websocket_connection( - Websocket(websocketio, websocket_id=None) - ) - # Block until the future completes. - future.result() + web_socket = Websocket(websocketio, websocket_id=None) + try: + future = self._open_websocket_connection(web_socket) + # Block until the future completes. + future.result() + except WebsocketException as e: + logger.error(f"Websocket error: {str(e)}") + except Exception as e: + logger.error(f"Unexpected error: {str(e)}") + finally: + # Make sure everything is closed and we can exit + web_socket.close() + # Force shutdown the executor to ensure the process can exit + self._executor.shutdown(wait=False) else: self._listen_on_port() @@ -424,13 +540,21 @@ def _listen_on_port(self): ) def _open_websocket_connection(self, web_socket): - presigned_url = self._eice_request_signer.get_presigned_url() - web_socket.connect(presigned_url, self._user_agent) + try: + presigned_url = self._eice_request_signer.get_presigned_url() + web_socket.connect(presigned_url, self._user_agent) - future = self._executor.submit(web_socket.write_data_from_input) + # Submit the task with a done callback to clean up resources + future = self._executor.submit(web_socket.write_data_from_input) - self._inflight_futures_and_websockets.append((future, web_socket)) - return future + # Store for cleanup + self._inflight_futures_and_websockets.append((future, web_socket)) + + return future + except Exception as e: + logger.error(f"Failed to open websocket connection: {str(e)}") + web_socket.close() + raise def _print_tcp_conn_closed(self, web_socket): def _on_done_callback(future):