diff --git a/.gitignore b/.gitignore index 896e771..c750ad5 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,6 @@ testconfig/ /rnsh.egg-info/ /build/ /dist/ -.pytest_cache/ \ No newline at end of file +.pytest_cache/ +*__pycache__ +/RNS diff --git a/pyproject.toml b/pyproject.toml index cb900ec..08ea7a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "rnsh" -version = "0.1.1" +version = "0.1.2" description = "Shell over Reticulum" authors = ["acehoss "] license = "MIT" @@ -9,7 +9,7 @@ readme = "README.md" [tool.poetry.dependencies] python = "^3.7" docopt = "^0.6.2" -rns = "^0.5.3" +rns = ">=0.5.9" # rns = { git = "https://github.com/acehoss/Reticulum.git", branch = "feature/channel" } # rns = { path = "../Reticulum/", develop = true } tomli = "^2.0.1" diff --git a/rnsh/args.py b/rnsh/args.py index 9e8e283..3a543f7 100644 --- a/rnsh/args.py +++ b/rnsh/args.py @@ -20,7 +20,7 @@ def _split_array_at(arr: [_T], at: _T) -> ([_T], [_T]): Usage: rnsh -l [-c ] [-i | -s ] [-v... | -q...] -p rnsh -l [-c ] [-i | -s ] [-v... | -q...] - [-b ] (-n | -a [-a ] ...) [-A | -C] + [-b ] [-n] [-a ] ([-a ] ...) [-A | -C] [[--] [ ...]] rnsh [-c ] [-i ] [-v... | -q...] -p rnsh [-c ] [-i ] [-v... | -q...] [-N] [-m] [-w ] @@ -40,7 +40,9 @@ def _split_array_at(arr: [_T], at: _T) -> ([_T], [_T]): user rnsh is running under will be used. -b --announce PERIOD Announce on startup and every PERIOD seconds Specify 0 for PERIOD to announce on startup only. - -a HASH --allowed HASH Specify identities allowed to connect + -a HASH --allowed HASH Specify identities allowed to connect. Allowed identities + can also be specified in ~/.rnsh/allowed_identities or + ~/.config/rnsh/allowed_identities, one hash per line. -n --no-auth Disable authentication -N --no-id Disable identify on connect -A --remote-command-as-args Concatenate remote command to argument list of /shell diff --git a/rnsh/initiator.py b/rnsh/initiator.py index 9f62ab8..085a8af 100644 --- a/rnsh/initiator.py +++ b/rnsh/initiator.py @@ -49,6 +49,7 @@ import contextlib import rnsh.args import pwd +import bz2 import rnsh.protocol as protocol import rnsh.helpers as helpers import rnsh.rnsh @@ -230,6 +231,7 @@ async def initiate(configdir: str, identitypath: str, verbosity: int, quietness: loop = asyncio.get_running_loop() state = InitiatorState.IS_INITIAL data_buffer = bytearray(sys.stdin.buffer.read()) if not os.isatty(sys.stdin.fileno()) else bytearray() + line_buffer = bytearray() await _initiate_link( configdir=configdir, @@ -273,14 +275,81 @@ def sigwinch_handler(): # log.debug("WindowChanged") winch = True + esc = False + pre_esc = True + line_mode = False + line_flush = False + blind_write_count = 0 + flush_chars = ["\x01", "\x03", "\x04", "\x05", "\x0c", "\x11", "\x13", "\x15", "\x19", "\t", "\x1A", "\x1B"] + def handle_escape(b): + nonlocal line_mode + if b == "~": + return "~" + elif b == "?": + os.write(1, "\n\r\n\rSupported rnsh escape sequences:".encode("utf-8")) + os.write(1, "\n\r ~~ Send the escape character by typing it twice".encode("utf-8")) + os.write(1, "\n\r ~. Terminate session and exit immediately".encode("utf-8")) + os.write(1, "\n\r ~L Toggle line-interactive mode".encode("utf-8")) + os.write(1, "\n\r ~? Display this quick reference\n\r".encode("utf-8")) + os.write(1, "\n\r(Escape sequences are only recognized immediately after newline)\n\r".encode("utf-8")) + elif b == ".": + _link.teardown() + elif b == "L": + line_mode = not line_mode + if line_mode: + os.write(1, "\n\rLine-interactive mode enabled\n\r".encode("utf-8")) + else: + os.write(1, "\n\rLine-interactive mode disabled\n\r".encode("utf-8")) + + return None + stdin_eof = False def stdin(): - nonlocal stdin_eof + nonlocal stdin_eof, pre_esc, esc, line_mode + nonlocal line_flush, blind_write_count try: - data = process.tty_read(sys.stdin.fileno()) - log.debug(f"stdin {data}") - if data is not None: - data_buffer.extend(data) + in_data = process.tty_read(sys.stdin.fileno()) + if in_data is not None: + data = bytearray() + for b in bytes(in_data): + c = chr(b) + if c == "\r": + pre_esc = True + line_flush = True + data.append(b) + elif line_mode and c in flush_chars: + line_flush = True + data.append(b) + elif line_mode and (c == "\b" or c == "\x7f"): + if len(line_buffer)>0: + line_buffer.pop(-1) + blind_write_count -= 1 + os.write(1, "\b \b".encode("utf-8")) + elif pre_esc == True and c == "~": + pre_esc = False + esc = True + elif esc == True: + ret = handle_escape(c) + if ret != None: + data.append(ord(ret)) + esc = False + else: + data.append(b) + + if not line_mode: + data_buffer.extend(data) + else: + line_buffer.extend(data) + if line_flush: + data_buffer.extend(line_buffer) + line_buffer.clear() + os.write(1, ("\b \b"*blind_write_count).encode("utf-8")) + line_flush = False + blind_write_count = 0 + else: + os.write(1, data) + blind_write_count += len(data) + except EOFError: if os.isatty(0): data_buffer.extend(process.CTRL_D) @@ -362,11 +431,42 @@ def stdin(): processed = False if channel.is_ready_to_send(): - stdin = data_buffer[:mdu] - data_buffer = data_buffer[mdu:] + def compress_adaptive(buf: bytes): + comp_tries = RNS.RawChannelWriter.COMPRESSION_TRIES + comp_try = 1 + comp_success = False + + chunk_len = len(buf) + if chunk_len > RNS.RawChannelWriter.MAX_CHUNK_LEN: + chunk_len = RNS.RawChannelWriter.MAX_CHUNK_LEN + chunk_segment = None + + chunk_segment = None + while chunk_len > 32 and comp_try < comp_tries: + chunk_segment_length = int(chunk_len/comp_try) + compressed_chunk = bz2.compress(buf[:chunk_segment_length]) + compressed_length = len(compressed_chunk) + if compressed_length < protocol.StreamDataMessage.MAX_DATA_LEN and compressed_length < chunk_segment_length: + comp_success = True + break + else: + comp_try += 1 + + if comp_success: + chunk = compressed_chunk + processed_length = chunk_segment_length + else: + chunk = bytes(buf[:protocol.StreamDataMessage.MAX_DATA_LEN]) + processed_length = len(chunk) + + return comp_success, processed_length, chunk + + comp_success, processed_length, chunk = compress_adaptive(data_buffer) + stdin = chunk + data_buffer = data_buffer[processed_length:] eof = not sent_eof and stdin_eof and len(stdin) == 0 if len(stdin) > 0 or eof: - channel.send(protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDIN, stdin, eof)) + channel.send(protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDIN, stdin, eof, comp_success)) sent_eof = eof processed = True diff --git a/rnsh/listener.py b/rnsh/listener.py index d3e617f..1a7d231 100644 --- a/rnsh/listener.py +++ b/rnsh/listener.py @@ -64,7 +64,9 @@ def _get_logger(name: str): _identity = None _reticulum = None _allow_all = False +_allowed_file = None _allowed_identity_hashes = [] +_allowed_file_identity_hashes = [] _cmd: [str] | None = None DATA_AVAIL_MSG = "data available" _finished: asyncio.Event = None @@ -88,12 +90,37 @@ def _sigint_handler(sig, loop): else: raise KeyboardInterrupt() +def _reload_allowed_file(): + global _allowed_file, _allowed_file_identity_hashes + log = _get_logger("_listen") + if _allowed_file != None: + try: + with open(_allowed_file, "r") as file: + dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH // 8) * 2 + added = 0 + line = 0 + _allowed_file_identity_hashes = [] + for allow in file.read().replace("\r", "").split("\n"): + line += 1 + if len(allow) == dest_len: + try: + destination_hash = bytes.fromhex(allow) + _allowed_file_identity_hashes.append(destination_hash) + added += 1 + except Exception: + log.debug(f"Discarded invalid Identity hash in {_allowed_file} at line {line}") + + ms = "y" if added == 1 else "ies" + log.debug(f"Loaded {added} allowed identit{ms} from "+str(_allowed_file)) + except Exception as e: + log.error(f"Error while reloading allowed indetities file: {e}") + async def listen(configdir, command, identitypath=None, service_name=None, verbosity=0, quietness=0, allowed=None, - disable_auth=None, announce_period=900, no_remote_command=True, remote_cmd_as_args=False, + allowed_file=None, disable_auth=None, announce_period=900, no_remote_command=True, remote_cmd_as_args=False, loop: asyncio.AbstractEventLoop = None): - global _identity, _allow_all, _allowed_identity_hashes, _reticulum, _cmd, _destination, _no_remote_command - global _remote_cmd_as_args, _finished + global _identity, _allow_all, _allowed_identity_hashes, _allowed_file, _allowed_file_identity_hashes + global _reticulum, _cmd, _destination, _no_remote_command, _remote_cmd_as_args, _finished log = _get_logger("_listen") if not loop: loop = asyncio.get_running_loop() @@ -135,6 +162,10 @@ async def listen(configdir, command, identitypath=None, service_name=None, verbo _allow_all = True session.ListenerSession.allow_all = True else: + if allowed_file is not None: + _allowed_file = allowed_file + _reload_allowed_file() + if allowed is not None: for a in allowed: try: @@ -154,10 +185,12 @@ async def listen(configdir, command, identitypath=None, service_name=None, verbo log.error(str(e)) exit(1) - if len(_allowed_identity_hashes) < 1 and not disable_auth: + if (len(_allowed_identity_hashes) < 1 and len(_allowed_file_identity_hashes) < 1) and not disable_auth: log.warning("Warning: No allowed identities configured, rnsh will not accept any connections!") def link_established(lnk: RNS.Link): + _reload_allowed_file() + session.ListenerSession.allowed_file_identity_hashes = _allowed_file_identity_hashes session.ListenerSession(session.RNSOutlet.get_outlet(lnk), lnk.get_channel(), loop) _destination.set_link_established_callback(link_established) diff --git a/rnsh/process.py b/rnsh/process.py index cecf512..ffbee79 100644 --- a/rnsh/process.py +++ b/rnsh/process.py @@ -525,7 +525,6 @@ def write(self, data: bytes): Write bytes to the stdin of the child process. :param data: bytes to write """ - self._log.debug(f"write({data})") os.write(self._child_stdin, data) def set_winsize(self, r: int, c: int, h: int, v: int): diff --git a/rnsh/rnsh.py b/rnsh/rnsh.py index b4f41d8..c1e8486 100644 --- a/rnsh/rnsh.py +++ b/rnsh/rnsh.py @@ -117,7 +117,13 @@ async def _rnsh_cli_main(): return 0 if args.listen: - # log.info("command " + args.command) + allowed_file = None + dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH//8)*2 + if os.path.isfile(os.path.expanduser("~/.config/rnsh/allowed_identities")): + allowed_file = os.path.expanduser("~/.config/rnsh/allowed_identities") + elif os.path.isfile(os.path.expanduser("~/.rnsh/allowed_identities")): + allowed_file = os.path.expanduser("~/.rnsh/allowed_identities") + await listener.listen(configdir=args.config, command=args.command_line, identitypath=args.identity, @@ -125,6 +131,7 @@ async def _rnsh_cli_main(): verbosity=args.verbose, quietness=args.quiet, allowed=args.allowed, + allowed_file=allowed_file, disable_auth=args.no_auth, announce_period=args.announce, no_remote_command=args.no_remote_cmd, diff --git a/rnsh/session.py b/rnsh/session.py index 1c948e7..1b4b906 100644 --- a/rnsh/session.py +++ b/rnsh/session.py @@ -12,6 +12,7 @@ from abc import abstractmethod, ABC from multiprocessing import Manager import os +import bz2 import RNS import logging as __logging @@ -68,6 +69,7 @@ def teardown(self): class ListenerSession: sessions: List[ListenerSession] = [] allowed_identity_hashes: [any] = [] + allowed_file_identity_hashes: [any] = [] allow_all: bool = False allow_remote_command: bool = False default_command: [str] = [] @@ -182,7 +184,7 @@ def _initiator_identified(self, outlet, identity): if self.state not in [LSState.LSSTATE_WAIT_IDENT, LSState.LSSTATE_WAIT_VERS]: self._protocol_error(LSState.LSSTATE_WAIT_IDENT.name) - if not self.allow_all and identity.hash not in self.allowed_identity_hashes: + if not self.allow_all and identity.hash not in self.allowed_identity_hashes and identity.hash not in self.allowed_file_identity_hashes: self.terminate("Identity is not allowed.") self.remote_identity = identity @@ -204,31 +206,59 @@ async def terminate_all(cls, reason: str): await asyncio.sleep(0) def pump(self) -> bool: + def compress_adaptive(buf: bytes): + comp_tries = RNS.RawChannelWriter.COMPRESSION_TRIES + comp_try = 1 + comp_success = False + + chunk_len = len(buf) + if chunk_len > RNS.RawChannelWriter.MAX_CHUNK_LEN: + chunk_len = RNS.RawChannelWriter.MAX_CHUNK_LEN + chunk_segment = None + + chunk_segment = None + while chunk_len > 32 and comp_try < comp_tries: + chunk_segment_length = int(chunk_len/comp_try) + compressed_chunk = bz2.compress(buf[:chunk_segment_length]) + compressed_length = len(compressed_chunk) + if compressed_length < protocol.StreamDataMessage.MAX_DATA_LEN and compressed_length < chunk_segment_length: + comp_success = True + break + else: + comp_try += 1 + + if comp_success: + chunk = compressed_chunk + processed_length = chunk_segment_length + else: + chunk = bytes(buf[:protocol.StreamDataMessage.MAX_DATA_LEN]) + processed_length = len(chunk) + + return comp_success, processed_length, chunk + try: if self.state != LSState.LSSTATE_RUNNING: return False elif not self.channel.is_ready_to_send(): return False elif len(self.stderr_buf) > 0: - mdu = protocol.StreamDataMessage.MAX_DATA_LEN - data = self.stderr_buf[:mdu] - self.stderr_buf = self.stderr_buf[mdu:] + comp_success, processed_length, data = compress_adaptive(self.stderr_buf) + self.stderr_buf = self.stderr_buf[processed_length:] send_eof = self.process.stderr_eof and len(data) == 0 and not self.stderr_eof_sent self.stderr_eof_sent = self.stderr_eof_sent or send_eof msg = protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDERR, - data, send_eof) + data, send_eof, comp_success) self.send(msg) if send_eof: self.stderr_eof_sent = True return True elif len(self.stdout_buf) > 0: - mdu = protocol.StreamDataMessage.MAX_DATA_LEN - data = self.stdout_buf[:mdu] - self.stdout_buf = self.stdout_buf[mdu:] + comp_success, processed_length, data = compress_adaptive(self.stdout_buf) + self.stdout_buf = self.stdout_buf[processed_length:] send_eof = self.process.stdout_eof and len(data) == 0 and not self.stdout_eof_sent self.stdout_eof_sent = self.stdout_eof_sent or send_eof msg = protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDOUT, - data, send_eof) + data, send_eof, comp_success) self.send(msg) if send_eof: self.stdout_eof_sent = True