diff --git a/AGENTS.md b/AGENTS.md index f37f7d6..dd2ab0c 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -30,10 +30,12 @@ The first implementation pass intentionally narrows the MVP: * .NET 10 stdio MCP server * `ssh_exec` * `terminal_start`, `terminal_write`, `terminal_read`, and `terminal_stop` +* `sftp_list`, `sftp_get`, and `sftp_put` * SSH.NET-based SSH connections * Tool-supplied `host`, `username`, optional `port`, optional `keyPath`, and optional `keyPassphrase` * Default key discovery from `~/.ssh/id_ed25519`, `~/.ssh/id_ecdsa`, then `~/.ssh/id_rsa` when `keyPath` is omitted -* No SSH agent, OpenSSH alias, `ssh -G`, ProxyJump, ProxyCommand, or SFTP support yet +* SFTP first for file transfer; `sftp_get` and `sftp_put` silently fall back to SCP if SFTP is unavailable +* No SSH agent, OpenSSH alias, `ssh -G`, ProxyJump, or ProxyCommand support yet * Basic audit logging and timeout enforcement ## Non-Goals @@ -169,21 +171,21 @@ Priority order: ```text 1. SFTP -2. SCP fallback (optional future enhancement) +2. SCP fallback for get/put ``` The implementation must: -* Validate SFTP subsystem availability during connection -* Return structured capability errors if unavailable -* Not silently downgrade to SCP +* Use SFTP for directory listing +* Use SFTP first for file get/put +* Silently fall back to SCP for file get/put if SFTP is unavailable Example error: ```json { "error": "sftp_unavailable", - "message": "Remote host does not expose the SFTP subsystem.", + "message": "Remote host does not expose the SFTP subsystem for directory listing.", "scpFallbackAvailable": false } ``` @@ -243,7 +245,6 @@ Input: { "host": "prod-api.example.com", "username": "deploy", - "shell": "bash", "cols": 120, "rows": 40, "port": 22, @@ -267,6 +268,7 @@ Requirements: * Maintain server-side session state * Support idle timeout cleanup * Use the same key-auth inputs and default key discovery as `ssh_exec` +* Use the remote account's default shell; do not write shell setup commands into the PTY after startup --- @@ -553,6 +555,9 @@ The MVP must include: * terminal_write * terminal_read * terminal_stop +* sftp_list +* sftp_get +* sftp_put * basic audit logging --- diff --git a/src/McpSsh.Server/Program.cs b/src/McpSsh.Server/Program.cs index 8bf68d4..3e1829a 100644 --- a/src/McpSsh.Server/Program.cs +++ b/src/McpSsh.Server/Program.cs @@ -1,5 +1,6 @@ using McpSsh.Server; using McpSsh.Server.Audit; +using McpSsh.Server.Sftp; using McpSsh.Server.Ssh; using McpSsh.Server.Terminal; using McpSsh.Server.Tools; @@ -14,11 +15,13 @@ builder.Logging.ClearProviders(); builder.Services.AddSingleton(); builder.Services.AddSingleton(); builder.Services.AddSingleton(); +builder.Services.AddSingleton(); builder.Services.AddSingleton(); builder.Services.AddSingleton(); builder.Services.AddSingleton(); builder.Services.AddSingleton(); builder.Services.AddSingleton(); +builder.Services.AddSingleton(); builder.Services .AddMcpServer() diff --git a/src/McpSsh.Server/Sftp/SftpModels.cs b/src/McpSsh.Server/Sftp/SftpModels.cs new file mode 100644 index 0000000..1a5262d --- /dev/null +++ b/src/McpSsh.Server/Sftp/SftpModels.cs @@ -0,0 +1,22 @@ +namespace McpSsh.Server.Sftp; + +public sealed record SftpEntry( + string Name, + string Path, + string Type, + long Size, + DateTime ModifiedUtc); + +public sealed record SftpListResult( + IReadOnlyList? Entries, + string? Error = null, + string? Message = null, + bool ScpFallbackAvailable = false); + +public sealed record SftpTransferResult( + long BytesTransferred, + string? LocalPath, + string? RemotePath, + string TransferProtocol, + string? Error = null, + string? Message = null); diff --git a/src/McpSsh.Server/Sftp/SftpService.cs b/src/McpSsh.Server/Sftp/SftpService.cs new file mode 100644 index 0000000..38ff87e --- /dev/null +++ b/src/McpSsh.Server/Sftp/SftpService.cs @@ -0,0 +1,418 @@ +using McpSsh.Server.Audit; +using McpSsh.Server.Ssh; +using Renci.SshNet.Common; + +namespace McpSsh.Server.Sftp; + +public sealed class SftpService +{ + public const int DefaultPort = 22; + public const long DefaultMaxTransferBytes = 100L * 1024 * 1024; + public const long MaxTransferBytesLimit = 10L * 1024 * 1024 * 1024; + + private readonly ISshClientFactory _clientFactory; + private readonly IAuditLogger _auditLogger; + private readonly ISystemClock _clock; + + public SftpService(ISshClientFactory clientFactory, IAuditLogger auditLogger, ISystemClock clock) + { + _clientFactory = clientFactory; + _auditLogger = auditLogger; + _clock = clock; + } + + public Task ListAsync( + string host, + string username, + string remotePath, + int? port, + string? keyPath, + string? keyPassphrase, + CancellationToken cancellationToken) + { + var request = ValidateRequest(host, username, port, keyPath, keyPassphrase); + if (string.IsNullOrWhiteSpace(remotePath)) + { + return Task.FromResult(new SftpListResult(null, "invalid_remote_path", "Remote path is required.")); + } + + return Task.Run(() => + { + try + { + using var client = _clientFactory.CreateSftpClient(request); + client.Connect(); + cancellationToken.ThrowIfCancellationRequested(); + + var entries = client + .ListDirectory(remotePath) + .Where(file => file.Name is not "." and not "..") + .Select(file => new SftpEntry( + file.Name, + CombineRemotePath(remotePath, file.Name), + file.IsDirectory ? "directory" : file.IsSymbolicLink ? "symlink" : "file", + file.Length, + file.LastWriteTimeUtc)) + .ToArray(); + + Log("sftp_list", request.Host, success: true, command: remotePath); + return new SftpListResult(entries); + } + catch (Exception ex) when (TryMapSshException(ex, out var error, out var message)) + { + Log("sftp_list", request.Host, success: false, errorCode: error, message: message, command: remotePath); + return new SftpListResult(null, error, message, ScpFallbackAvailable: false); + } + }, cancellationToken); + } + + public Task GetAsync( + string host, + string username, + string remotePath, + string localPath, + int? port, + string? keyPath, + string? keyPassphrase, + bool? overwrite, + long? maxBytes, + CancellationToken cancellationToken) + { + var request = ValidateRequest(host, username, port, keyPath, keyPassphrase); + var transfer = ValidateTransferPaths(remotePath, localPath, maxBytes); + var resolvedLocalPath = ResolveLocalPath(transfer.LocalPath); + var allowOverwrite = overwrite ?? false; + + if (File.Exists(resolvedLocalPath) && !allowOverwrite) + { + return Task.FromResult(new SftpTransferResult(0, resolvedLocalPath, transfer.RemotePath, "none", "local_file_exists", "Local file already exists.")); + } + + return Task.Run(() => + { + Directory.CreateDirectory(Path.GetDirectoryName(resolvedLocalPath) ?? Directory.GetCurrentDirectory()); + var tempPath = Path.Combine(Path.GetDirectoryName(resolvedLocalPath) ?? Directory.GetCurrentDirectory(), $".{Path.GetFileName(resolvedLocalPath)}.{Guid.NewGuid():N}.tmp"); + + try + { + try + { + var bytes = DownloadWithSftp(request, transfer.RemotePath, tempPath, transfer.MaxBytes, cancellationToken); + MoveDownloadedFile(tempPath, resolvedLocalPath, allowOverwrite); + Log("sftp_get", request.Host, success: true, command: transfer.RemotePath); + return new SftpTransferResult(bytes, resolvedLocalPath, transfer.RemotePath, "sftp"); + } + catch (Exception ex) when (IsSftpFallbackCandidate(ex)) + { + var bytes = DownloadWithScp(request, transfer.RemotePath, tempPath, transfer.MaxBytes, cancellationToken); + MoveDownloadedFile(tempPath, resolvedLocalPath, allowOverwrite); + Log("sftp_get", request.Host, success: true, command: transfer.RemotePath); + return new SftpTransferResult(bytes, resolvedLocalPath, transfer.RemotePath, "scp"); + } + } + catch (Exception ex) when (TryMapSshException(ex, out var error, out var message)) + { + DeleteQuietly(tempPath); + Log("sftp_get", request.Host, success: false, errorCode: error, message: message, command: transfer.RemotePath); + return new SftpTransferResult(0, resolvedLocalPath, transfer.RemotePath, "none", error, message); + } + }, cancellationToken); + } + + public Task PutAsync( + string host, + string username, + string localPath, + string remotePath, + int? port, + string? keyPath, + string? keyPassphrase, + bool? overwrite, + long? maxBytes, + CancellationToken cancellationToken) + { + return Task.Run(() => + { + SshConnectionRequest? request = null; + string? resolvedLocalPath = null; + string? resolvedRemotePath = null; + + try + { + request = ValidateRequest(host, username, port, keyPath, keyPassphrase); + var transfer = ValidateTransferPaths(remotePath, localPath, maxBytes); + resolvedRemotePath = transfer.RemotePath; + resolvedLocalPath = ResolveExistingLocalFile(transfer.LocalPath, transfer.MaxBytes); + var allowOverwrite = overwrite ?? false; + + try + { + var bytes = UploadWithSftp(request, resolvedLocalPath, resolvedRemotePath, allowOverwrite, cancellationToken); + Log("sftp_put", request.Host, success: true, command: resolvedRemotePath); + return new SftpTransferResult(bytes, resolvedLocalPath, resolvedRemotePath, "sftp"); + } + catch (Exception ex) when (IsSftpFallbackCandidate(ex)) + { + var bytes = UploadWithScp(request, resolvedLocalPath, resolvedRemotePath, allowOverwrite, cancellationToken); + Log("sftp_put", request.Host, success: true, command: resolvedRemotePath); + return new SftpTransferResult(bytes, resolvedLocalPath, resolvedRemotePath, "scp"); + } + } + catch (Exception ex) when (TryMapSshException(ex, out var error, out var message)) + { + Log("sftp_put", request?.Host ?? host, success: false, errorCode: error, message: message, command: resolvedRemotePath ?? remotePath); + return new SftpTransferResult(0, resolvedLocalPath, resolvedRemotePath ?? remotePath, "none", error, message); + } + }, cancellationToken); + } + + private long DownloadWithSftp(SshConnectionRequest request, string remotePath, string tempPath, long maxBytes, CancellationToken cancellationToken) + { + using var client = _clientFactory.CreateSftpClient(request); + client.Connect(); + cancellationToken.ThrowIfCancellationRequested(); + + var attributes = client.GetAttributes(remotePath); + if (attributes.IsDirectory) + { + throw new SshToolException("remote_path_is_directory", "Remote path is a directory."); + } + + if (attributes.Size > maxBytes) + { + throw new SshToolException("download_too_large", $"Remote file size {attributes.Size} exceeds the configured maximum of {maxBytes} bytes."); + } + + using var output = File.Create(tempPath); + client.DownloadFile(remotePath, output); + return output.Length; + } + + private long DownloadWithScp(SshConnectionRequest request, string remotePath, string tempPath, long maxBytes, CancellationToken cancellationToken) + { + using var client = _clientFactory.CreateScpClient(request); + client.Connect(); + cancellationToken.ThrowIfCancellationRequested(); + + using var output = File.Create(tempPath); + client.Download(remotePath, output); + output.Flush(); + if (output.Length > maxBytes) + { + throw new SshToolException("download_too_large", $"Downloaded file size {output.Length} exceeds the configured maximum of {maxBytes} bytes."); + } + + return output.Length; + } + + private long UploadWithSftp(SshConnectionRequest request, string localPath, string remotePath, bool overwrite, CancellationToken cancellationToken) + { + using var client = _clientFactory.CreateSftpClient(request); + client.Connect(); + cancellationToken.ThrowIfCancellationRequested(); + + if (!overwrite && client.Exists(remotePath)) + { + throw new SshToolException("remote_file_exists", "Remote file already exists."); + } + + using var input = File.OpenRead(localPath); + client.UploadFile(input, remotePath, overwrite); + return input.Length; + } + + private long UploadWithScp(SshConnectionRequest request, string localPath, string remotePath, bool overwrite, CancellationToken cancellationToken) + { + if (!overwrite && CheckRemoteExistsWithSsh(request, remotePath, cancellationToken)) + { + throw new SshToolException("remote_file_exists", "Remote file already exists."); + } + + using var client = _clientFactory.CreateScpClient(request); + client.Connect(); + cancellationToken.ThrowIfCancellationRequested(); + + var file = new FileInfo(localPath); + client.Upload(file, remotePath); + return file.Length; + } + + private bool CheckRemoteExistsWithSsh(SshConnectionRequest request, string remotePath, CancellationToken cancellationToken) + { + using var client = _clientFactory.CreateSshClient(request); + client.Connect(); + cancellationToken.ThrowIfCancellationRequested(); + using var command = client.CreateCommand($"test -e {RemoteShellCommand.Quote(remotePath)}"); + command.Execute(); + return command.ExitStatus == 0; + } + + private static SshConnectionRequest ValidateRequest(string host, string username, int? port, string? keyPath, string? keyPassphrase) + { + if (string.IsNullOrWhiteSpace(host)) + { + throw new SshToolException("invalid_host", "Host is required."); + } + + if (string.IsNullOrWhiteSpace(username)) + { + throw new SshToolException("invalid_username", "Username is required."); + } + + var resolvedPort = port ?? DefaultPort; + if (resolvedPort is < 1 or > 65535) + { + throw new SshToolException("invalid_port", "Port must be between 1 and 65535."); + } + + return new SshConnectionRequest( + host.Trim(), + username.Trim(), + resolvedPort, + string.IsNullOrWhiteSpace(keyPath) ? null : keyPath.Trim(), + string.IsNullOrEmpty(keyPassphrase) ? null : keyPassphrase); + } + + private static (string RemotePath, string LocalPath, long MaxBytes) ValidateTransferPaths(string remotePath, string localPath, long? maxBytes) + { + if (string.IsNullOrWhiteSpace(remotePath)) + { + throw new SshToolException("invalid_remote_path", "Remote path is required."); + } + + if (string.IsNullOrWhiteSpace(localPath)) + { + throw new SshToolException("invalid_local_path", "Local path is required."); + } + + var resolvedMaxBytes = maxBytes ?? DefaultMaxTransferBytes; + if (resolvedMaxBytes is < 1 or > MaxTransferBytesLimit) + { + throw new SshToolException("invalid_max_bytes", $"maxBytes must be between 1 and {MaxTransferBytesLimit}."); + } + + return (remotePath.Trim(), localPath.Trim(), resolvedMaxBytes); + } + + private static string ResolveExistingLocalFile(string localPath, long maxBytes) + { + var resolved = ResolveLocalPath(localPath); + if (!File.Exists(resolved)) + { + throw new SshToolException("local_file_not_found", "Local file does not exist."); + } + + var file = new FileInfo(resolved); + if (file.Length > maxBytes) + { + throw new SshToolException("upload_too_large", $"Local file size {file.Length} exceeds the configured maximum of {maxBytes} bytes."); + } + + return resolved; + } + + private static string ResolveLocalPath(string localPath) + { + var expanded = ExpandHomeDirectory(localPath); + var fullPath = Path.GetFullPath(expanded); + var currentDirectory = Path.GetFullPath(Directory.GetCurrentDirectory()); + if (!fullPath.StartsWith(currentDirectory + Path.DirectorySeparatorChar, StringComparison.Ordinal) && + !string.Equals(fullPath, currentDirectory, StringComparison.Ordinal)) + { + throw new SshToolException("unsafe_local_path", "Local paths must stay within the current working directory."); + } + + return fullPath; + } + + private static string ExpandHomeDirectory(string path) + { + if (path == "~") + { + return Environment.GetFolderPath(Environment.SpecialFolder.UserProfile); + } + + if (path.StartsWith("~/", StringComparison.Ordinal)) + { + return Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.UserProfile), path[2..]); + } + + return path; + } + + private static void MoveDownloadedFile(string tempPath, string targetPath, bool overwrite) + { + File.Move(tempPath, targetPath, overwrite); + } + + private static void DeleteQuietly(string path) + { + try + { + if (File.Exists(path)) + { + File.Delete(path); + } + } + catch + { + } + } + + private static bool IsSftpFallbackCandidate(Exception ex) + { + return ex is SshException or SftpPathNotFoundException or SftpPermissionDeniedException; + } + + private static bool TryMapSshException(Exception ex, out string error, out string message) + { + switch (ex) + { + case SshToolException toolException: + error = toolException.ErrorCode; + message = toolException.Message; + return true; + case SshAuthenticationException: + error = "ssh_authentication_failed"; + message = "SSH key authentication failed. Check the username, key path, and key passphrase."; + return true; + case SftpPermissionDeniedException: + error = "permission_denied"; + message = ex.Message; + return true; + case SftpPathNotFoundException: + error = "path_not_found"; + message = ex.Message; + return true; + case SshConnectionException: + case SshException: + error = "ssh_error"; + message = ex.Message; + return true; + case IOException: + case UnauthorizedAccessException: + error = "file_error"; + message = ex.Message; + return true; + default: + error = "transfer_error"; + message = ex.Message; + return true; + } + } + + private static string CombineRemotePath(string directory, string name) + { + if (directory.EndsWith('/')) + { + return directory + name; + } + + return directory + "/" + name; + } + + private void Log(string tool, string? host, bool success, string? errorCode = null, string? message = null, string? command = null) + { + _auditLogger.Log(new AuditEvent(_clock.UtcNow, tool, host, null, command, success, 0, errorCode, message)); + } +} diff --git a/src/McpSsh.Server/Ssh/SshClientFactory.cs b/src/McpSsh.Server/Ssh/SshClientFactory.cs new file mode 100644 index 0000000..5f9d96a --- /dev/null +++ b/src/McpSsh.Server/Ssh/SshClientFactory.cs @@ -0,0 +1,70 @@ +using Renci.SshNet; + +namespace McpSsh.Server.Ssh; + +public interface ISshClientFactory +{ + SshClient CreateSshClient(SshConnectionRequest request); + + SftpClient CreateSftpClient(SshConnectionRequest request); + + ScpClient CreateScpClient(SshConnectionRequest request); +} + +public sealed record SshConnectionRequest( + string Host, + string Username, + int Port, + string? KeyPath, + string? KeyPassphrase, + TimeSpan? Timeout = null); + +public sealed class SshNetClientFactory : ISshClientFactory +{ + private readonly ISshKeyResolver _keyResolver; + + public SshNetClientFactory(ISshKeyResolver keyResolver) + { + _keyResolver = keyResolver; + } + + public SshClient CreateSshClient(SshConnectionRequest request) + { + var client = new SshClient(CreateConnectionInfo(request)); + ApplyTimeout(client, request.Timeout); + return client; + } + + public SftpClient CreateSftpClient(SshConnectionRequest request) + { + var client = new SftpClient(CreateConnectionInfo(request)); + ApplyTimeout(client, request.Timeout); + return client; + } + + public ScpClient CreateScpClient(SshConnectionRequest request) + { + var client = new ScpClient(CreateConnectionInfo(request)); + ApplyTimeout(client, request.Timeout); + return client; + } + + private ConnectionInfo CreateConnectionInfo(SshConnectionRequest request) + { + var keyPath = _keyResolver.ResolveKeyPath(request.KeyPath); + var keyFile = string.IsNullOrEmpty(request.KeyPassphrase) + ? new PrivateKeyFile(keyPath) + : new PrivateKeyFile(keyPath, request.KeyPassphrase); + + var auth = new PrivateKeyAuthenticationMethod(request.Username, keyFile); + return new ConnectionInfo(request.Host, request.Port, request.Username, auth); + } + + private static void ApplyTimeout(BaseClient client, TimeSpan? timeout) + { + if (timeout is { } value) + { + client.ConnectionInfo.Timeout = value; + } + } +} diff --git a/src/McpSsh.Server/Ssh/SshCommandExecutor.cs b/src/McpSsh.Server/Ssh/SshCommandExecutor.cs index 0505dbc..cd74eb0 100644 --- a/src/McpSsh.Server/Ssh/SshCommandExecutor.cs +++ b/src/McpSsh.Server/Ssh/SshCommandExecutor.cs @@ -10,11 +10,11 @@ public interface ISshCommandExecutor public sealed class SshNetCommandExecutor : ISshCommandExecutor { - private readonly ISshKeyResolver _keyResolver; + private readonly ISshClientFactory _clientFactory; - public SshNetCommandExecutor(ISshKeyResolver keyResolver) + public SshNetCommandExecutor(ISshClientFactory clientFactory) { - _keyResolver = keyResolver; + _clientFactory = clientFactory; } public Task ExecuteAsync(SshExecRequest request, CancellationToken cancellationToken) @@ -26,13 +26,13 @@ public sealed class SshNetCommandExecutor : ISshCommandExecutor { try { - var keyPath = _keyResolver.ResolveKeyPath(request.KeyPath); - using var keyFile = string.IsNullOrEmpty(request.KeyPassphrase) - ? new PrivateKeyFile(keyPath) - : new PrivateKeyFile(keyPath, request.KeyPassphrase); - using var client = new SshClient(request.Host, request.Port, request.Username, keyFile); - - client.ConnectionInfo.Timeout = TimeSpan.FromSeconds(request.TimeoutSeconds); + using var client = _clientFactory.CreateSshClient(new SshConnectionRequest( + request.Host, + request.Username, + request.Port, + request.KeyPath, + request.KeyPassphrase, + TimeSpan.FromSeconds(request.TimeoutSeconds))); client.Connect(); cancellationToken.ThrowIfCancellationRequested(); diff --git a/src/McpSsh.Server/Terminal/TerminalConnection.cs b/src/McpSsh.Server/Terminal/TerminalConnection.cs index bcba1ac..0ca6f5e 100644 --- a/src/McpSsh.Server/Terminal/TerminalConnection.cs +++ b/src/McpSsh.Server/Terminal/TerminalConnection.cs @@ -21,22 +21,23 @@ public interface ITerminalConnectionFactory public sealed class SshNetTerminalConnectionFactory : ITerminalConnectionFactory { private const int ShellBufferSize = 16 * 1024; - private readonly ISshKeyResolver _keyResolver; + private readonly ISshClientFactory _clientFactory; - public SshNetTerminalConnectionFactory(ISshKeyResolver keyResolver) + public SshNetTerminalConnectionFactory(ISshClientFactory clientFactory) { - _keyResolver = keyResolver; + _clientFactory = clientFactory; } public ITerminalConnection Create(TerminalStartRequest request) { try { - var keyPath = _keyResolver.ResolveKeyPath(request.KeyPath); - var keyFile = string.IsNullOrEmpty(request.KeyPassphrase) - ? new PrivateKeyFile(keyPath) - : new PrivateKeyFile(keyPath, request.KeyPassphrase); - var client = new SshClient(request.Host, request.Port, request.Username, keyFile); + var client = _clientFactory.CreateSshClient(new SshConnectionRequest( + request.Host, + request.Username, + request.Port, + request.KeyPath, + request.KeyPassphrase)); client.Connect(); @@ -48,7 +49,7 @@ public sealed class SshNetTerminalConnectionFactory : ITerminalConnectionFactory height: 0, bufferSize: ShellBufferSize); - return new SshNetTerminalConnection(client, stream, keyFile); + return new SshNetTerminalConnection(client, stream); } catch (SshAuthenticationException ex) { @@ -64,13 +65,11 @@ public sealed class SshNetTerminalConnectionFactory : ITerminalConnectionFactory { private readonly SshClient _client; private readonly ShellStream _stream; - private readonly PrivateKeyFile _keyFile; - public SshNetTerminalConnection(SshClient client, ShellStream stream, PrivateKeyFile keyFile) + public SshNetTerminalConnection(SshClient client, ShellStream stream) { _client = client; _stream = stream; - _keyFile = keyFile; } public bool DataAvailable => _stream.DataAvailable; @@ -89,7 +88,6 @@ public sealed class SshNetTerminalConnectionFactory : ITerminalConnectionFactory { _stream.Dispose(); _client.Dispose(); - _keyFile.Dispose(); } } } diff --git a/src/McpSsh.Server/Terminal/TerminalResults.cs b/src/McpSsh.Server/Terminal/TerminalResults.cs index 64a1b0c..59ee554 100644 --- a/src/McpSsh.Server/Terminal/TerminalResults.cs +++ b/src/McpSsh.Server/Terminal/TerminalResults.cs @@ -3,7 +3,6 @@ namespace McpSsh.Server.Terminal; public sealed record TerminalStartRequest( string Host, string Username, - string? Shell, int Cols, int Rows, int Port, diff --git a/src/McpSsh.Server/Terminal/TerminalSessionManager.cs b/src/McpSsh.Server/Terminal/TerminalSessionManager.cs index 691acd9..a391427 100644 --- a/src/McpSsh.Server/Terminal/TerminalSessionManager.cs +++ b/src/McpSsh.Server/Terminal/TerminalSessionManager.cs @@ -34,7 +34,6 @@ public sealed class TerminalSessionManager : IDisposable public Task StartAsync( string host, string username, - string? shell, int? cols, int? rows, int? port, @@ -43,7 +42,7 @@ public sealed class TerminalSessionManager : IDisposable int? idleTimeoutSeconds, CancellationToken cancellationToken) { - var request = ValidateStart(host, username, shell, cols, rows, port, keyPath, keyPassphrase, idleTimeoutSeconds); + var request = ValidateStart(host, username, cols, rows, port, keyPath, keyPassphrase, idleTimeoutSeconds); var started = _clock.UtcNow; try @@ -59,19 +58,14 @@ public sealed class TerminalSessionManager : IDisposable throw new SshToolException("terminal_session_conflict", "Unable to allocate a unique terminal session ID."); } - if (!string.IsNullOrWhiteSpace(request.Shell)) - { - connection.Write($"exec {request.Shell}\n"); - } - session.ReaderTask = Task.Run(() => ReadLoopAsync(session), CancellationToken.None); - Log("terminal_start", request.Host, success: true, command: request.Shell); + Log("terminal_start", request.Host, success: true); return Task.FromResult(new TerminalStartResult(sessionId)); } catch (SshToolException ex) { - Log("terminal_start", request.Host, success: false, errorCode: ex.ErrorCode, message: ex.Message, command: request.Shell); + Log("terminal_start", request.Host, success: false, errorCode: ex.ErrorCode, message: ex.Message); return Task.FromResult(new TerminalStartResult(null, ex.ErrorCode, ex.Message)); } } @@ -183,7 +177,7 @@ public sealed class TerminalSessionManager : IDisposable } } - private static TerminalStartRequest ValidateStart(string host, string username, string? shell, int? cols, int? rows, int? port, string? keyPath, string? keyPassphrase, int? idleTimeoutSeconds) + private static TerminalStartRequest ValidateStart(string host, string username, int? cols, int? rows, int? port, string? keyPath, string? keyPassphrase, int? idleTimeoutSeconds) { if (string.IsNullOrWhiteSpace(host)) { @@ -222,7 +216,6 @@ public sealed class TerminalSessionManager : IDisposable return new TerminalStartRequest( host.Trim(), username.Trim(), - string.IsNullOrWhiteSpace(shell) ? null : shell.Trim(), resolvedCols, resolvedRows, resolvedPort, diff --git a/src/McpSsh.Server/Tools/SshTools.cs b/src/McpSsh.Server/Tools/SshTools.cs index e1696ae..dd11137 100644 --- a/src/McpSsh.Server/Tools/SshTools.cs +++ b/src/McpSsh.Server/Tools/SshTools.cs @@ -1,4 +1,5 @@ using System.ComponentModel; +using McpSsh.Server.Sftp; using McpSsh.Server.Ssh; using McpSsh.Server.Terminal; using ModelContextProtocol.Server; @@ -10,11 +11,13 @@ public sealed class SshTools { private readonly SshExecService _sshExecService; private readonly TerminalSessionManager _terminalSessionManager; + private readonly SftpService _sftpService; - public SshTools(SshExecService sshExecService, TerminalSessionManager terminalSessionManager) + public SshTools(SshExecService sshExecService, TerminalSessionManager terminalSessionManager, SftpService sftpService) { _sshExecService = sshExecService; _terminalSessionManager = terminalSessionManager; + _sftpService = sftpService; } [McpServerTool(Name = "ssh_exec", Destructive = true)] @@ -38,7 +41,6 @@ public sealed class SshTools public Task StartTerminalAsync( [Description("Remote hostname or IP address. OpenSSH aliases are not supported in this vertical slice.")] string host, [Description("Remote SSH username.")] string username, - [Description("Optional shell to exec after PTY allocation, for example bash or sh.")] string? shell = null, [Description("Terminal columns. Defaults to 120.")] int? cols = null, [Description("Terminal rows. Defaults to 40.")] int? rows = null, [Description("Remote SSH port. Defaults to 22.")] int? port = null, @@ -47,7 +49,7 @@ public sealed class SshTools [Description("Idle timeout in seconds. Defaults to 900.")] int? idleTimeoutSeconds = null, CancellationToken cancellationToken = default) { - return _terminalSessionManager.StartAsync(host, username, shell, cols, rows, port, keyPath, keyPassphrase, idleTimeoutSeconds, cancellationToken); + return _terminalSessionManager.StartAsync(host, username, cols, rows, port, keyPath, keyPassphrase, idleTimeoutSeconds, cancellationToken); } [McpServerTool(Name = "terminal_write", Destructive = true)] @@ -75,4 +77,52 @@ public sealed class SshTools { return _terminalSessionManager.Stop(sessionId); } + + [McpServerTool(Name = "sftp_list", Destructive = false)] + [Description("List remote directory contents over SFTP.")] + public Task ListSftpAsync( + [Description("Remote hostname or IP address. OpenSSH aliases are not supported in this vertical slice.")] string host, + [Description("Remote SSH username.")] string username, + [Description("Remote directory path to list.")] string remotePath, + [Description("Remote SSH port. Defaults to 22.")] int? port = null, + [Description("Optional local private key path. Defaults to ~/.ssh/id_ed25519, ~/.ssh/id_ecdsa, then ~/.ssh/id_rsa.")] string? keyPath = null, + [Description("Optional private key passphrase. Use only with trusted MCP clients.")] string? keyPassphrase = null, + CancellationToken cancellationToken = default) + { + return _sftpService.ListAsync(host, username, remotePath, port, keyPath, keyPassphrase, cancellationToken); + } + + [McpServerTool(Name = "sftp_get", Destructive = false)] + [Description("Download a remote file using SFTP, silently falling back to SCP when SFTP is unavailable.")] + public Task GetSftpAsync( + [Description("Remote hostname or IP address. OpenSSH aliases are not supported in this vertical slice.")] string host, + [Description("Remote SSH username.")] string username, + [Description("Remote file path to download.")] string remotePath, + [Description("Local destination path. Must stay within the server working directory.")] string localPath, + [Description("Remote SSH port. Defaults to 22.")] int? port = null, + [Description("Optional local private key path. Defaults to ~/.ssh/id_ed25519, ~/.ssh/id_ecdsa, then ~/.ssh/id_rsa.")] string? keyPath = null, + [Description("Optional private key passphrase. Use only with trusted MCP clients.")] string? keyPassphrase = null, + [Description("Overwrite an existing local file. Defaults to false.")] bool? overwrite = null, + [Description("Maximum download size in bytes. Defaults to 104857600.")] long? maxBytes = null, + CancellationToken cancellationToken = default) + { + return _sftpService.GetAsync(host, username, remotePath, localPath, port, keyPath, keyPassphrase, overwrite, maxBytes, cancellationToken); + } + + [McpServerTool(Name = "sftp_put", Destructive = true)] + [Description("Upload a local file using SFTP, silently falling back to SCP when SFTP is unavailable.")] + public Task PutSftpAsync( + [Description("Remote hostname or IP address. OpenSSH aliases are not supported in this vertical slice.")] string host, + [Description("Remote SSH username.")] string username, + [Description("Local file path to upload. Must stay within the server working directory.")] string localPath, + [Description("Remote destination file path.")] string remotePath, + [Description("Remote SSH port. Defaults to 22.")] int? port = null, + [Description("Optional local private key path. Defaults to ~/.ssh/id_ed25519, ~/.ssh/id_ecdsa, then ~/.ssh/id_rsa.")] string? keyPath = null, + [Description("Optional private key passphrase. Use only with trusted MCP clients.")] string? keyPassphrase = null, + [Description("Overwrite an existing remote file. Defaults to false.")] bool? overwrite = null, + [Description("Maximum upload size in bytes. Defaults to 104857600.")] long? maxBytes = null, + CancellationToken cancellationToken = default) + { + return _sftpService.PutAsync(host, username, localPath, remotePath, port, keyPath, keyPassphrase, overwrite, maxBytes, cancellationToken); + } } diff --git a/tests/McpSsh.Tests/SftpServiceTests.cs b/tests/McpSsh.Tests/SftpServiceTests.cs new file mode 100644 index 0000000..fc68ecc --- /dev/null +++ b/tests/McpSsh.Tests/SftpServiceTests.cs @@ -0,0 +1,75 @@ +using McpSsh.Server; +using McpSsh.Server.Audit; +using McpSsh.Server.Sftp; +using McpSsh.Server.Ssh; +using Renci.SshNet; + +namespace McpSsh.Tests; + +public sealed class SftpServiceTests +{ + [Fact] + public async Task PutAsync_ReturnsErrorWhenLocalFileIsOutsideWorkingDirectory() + { + var service = CreateService(); + + var result = await service.PutAsync("host", "user", "/tmp/file.txt", "/tmp/file.txt", null, null, null, null, null, CancellationToken.None); + + Assert.Equal("unsafe_local_path", result.Error); + } + + [Fact] + public async Task PutAsync_ReturnsErrorWhenLocalFileIsMissing() + { + var service = CreateService(); + + var result = await service.PutAsync("host", "user", "missing-file.txt", "/tmp/file.txt", null, null, null, null, null, CancellationToken.None); + + Assert.Equal("local_file_not_found", result.Error); + } + + [Fact] + public async Task GetAsync_ReturnsErrorWhenLocalFileExistsAndOverwriteIsFalse() + { + var path = Path.Combine("sftp-test-existing.txt"); + await File.WriteAllTextAsync(path, "existing"); + try + { + var service = CreateService(); + + var result = await service.GetAsync("host", "user", "/tmp/file.txt", path, null, null, null, overwrite: false, maxBytes: null, CancellationToken.None); + + Assert.Equal("local_file_exists", result.Error); + } + finally + { + File.Delete(path); + } + } + + private static SftpService CreateService() + { + return new SftpService(new ThrowingClientFactory(), new CapturingAuditLogger(), new FixedClock()); + } + + private sealed class ThrowingClientFactory : ISshClientFactory + { + public SshClient CreateSshClient(SshConnectionRequest request) => throw new NotSupportedException(); + + public SftpClient CreateSftpClient(SshConnectionRequest request) => throw new NotSupportedException(); + + public ScpClient CreateScpClient(SshConnectionRequest request) => throw new NotSupportedException(); + } + + private sealed class CapturingAuditLogger : IAuditLogger + { + public void Log(AuditEvent auditEvent) + { + } + } + + private sealed class FixedClock : ISystemClock + { + public DateTimeOffset UtcNow => DateTimeOffset.Parse("2026-05-24T12:00:00Z"); + } +} diff --git a/tests/McpSsh.Tests/TerminalSessionManagerTests.cs b/tests/McpSsh.Tests/TerminalSessionManagerTests.cs index 4a9ef1c..32a06be 100644 --- a/tests/McpSsh.Tests/TerminalSessionManagerTests.cs +++ b/tests/McpSsh.Tests/TerminalSessionManagerTests.cs @@ -8,11 +8,11 @@ namespace McpSsh.Tests; public sealed class TerminalSessionManagerTests { [Fact] - public async Task StartAsync_CreatesSessionAndExecsRequestedShell() + public async Task StartAsync_CreatesSessionWithoutWritingShellSetup() { using var manager = CreateManager(out var factory, out _); - var result = await manager.StartAsync(" prod-api ", " deploy ", "bash", null, null, null, "/keys/id", "secret", null, CancellationToken.None); + var result = await manager.StartAsync(" prod-api ", " deploy ", null, null, null, "/keys/id", "secret", null, CancellationToken.None); Assert.Null(result.Error); Assert.StartsWith("term_", result.SessionId); @@ -21,14 +21,14 @@ public sealed class TerminalSessionManagerTests Assert.Equal("deploy", factory.Request.Username); Assert.Equal("/keys/id", factory.Request.KeyPath); Assert.Equal("secret", factory.Request.KeyPassphrase); - Assert.Contains("exec bash\n", factory.Connection.Writes); + Assert.Empty(factory.Connection.Writes); } [Fact] public async Task Write_SendsInputToActiveSession() { using var manager = CreateManager(out var factory, out _); - var start = await manager.StartAsync("prod-api", "deploy", null, null, null, null, null, null, null, CancellationToken.None); + var start = await manager.StartAsync("prod-api", "deploy", null, null, null, null, null, null, CancellationToken.None); var result = manager.Write(start.SessionId!, "uptime\n"); @@ -40,7 +40,7 @@ public sealed class TerminalSessionManagerTests public async Task Read_DrainsBufferedOutputAndReportsTruncation() { using var manager = CreateManager(out var factory, out _); - var start = await manager.StartAsync("prod-api", "deploy", null, null, null, null, null, null, null, CancellationToken.None); + var start = await manager.StartAsync("prod-api", "deploy", null, null, null, null, null, null, CancellationToken.None); factory.Connection.QueueOutput("abcdef"); var first = await ReadUntilOutputAsync(manager, start.SessionId!, 3); @@ -67,7 +67,7 @@ public sealed class TerminalSessionManagerTests public async Task Stop_DisposesAndRemovesSession() { using var manager = CreateManager(out var factory, out _); - var start = await manager.StartAsync("prod-api", "deploy", null, null, null, null, null, null, null, CancellationToken.None); + var start = await manager.StartAsync("prod-api", "deploy", null, null, null, null, null, null, CancellationToken.None); var result = manager.Stop(start.SessionId!); var writeAfterStop = manager.Write(start.SessionId!, "pwd\n");