Add SFTP tools with SCP fallback

This commit is contained in:
Vibe Myass
2026-05-24 21:18:09 +00:00
parent d3b39c590a
commit 8afa6dee62
12 changed files with 684 additions and 51 deletions

View File

@@ -30,10 +30,12 @@ The first implementation pass intentionally narrows the MVP:
* .NET 10 stdio MCP server
* `ssh_exec`
* `terminal_start`, `terminal_write`, `terminal_read`, and `terminal_stop`
* `sftp_list`, `sftp_get`, and `sftp_put`
* SSH.NET-based SSH connections
* Tool-supplied `host`, `username`, optional `port`, optional `keyPath`, and optional `keyPassphrase`
* Default key discovery from `~/.ssh/id_ed25519`, `~/.ssh/id_ecdsa`, then `~/.ssh/id_rsa` when `keyPath` is omitted
* No SSH agent, OpenSSH alias, `ssh -G`, ProxyJump, ProxyCommand, or SFTP support yet
* SFTP first for file transfer; `sftp_get` and `sftp_put` silently fall back to SCP if SFTP is unavailable
* No SSH agent, OpenSSH alias, `ssh -G`, ProxyJump, or ProxyCommand support yet
* Basic audit logging and timeout enforcement
## Non-Goals
@@ -169,21 +171,21 @@ Priority order:
```text
1. SFTP
2. SCP fallback (optional future enhancement)
2. SCP fallback for get/put
```
The implementation must:
* Validate SFTP subsystem availability during connection
* Return structured capability errors if unavailable
* Not silently downgrade to SCP
* Use SFTP for directory listing
* Use SFTP first for file get/put
* Silently fall back to SCP for file get/put if SFTP is unavailable
Example error:
```json
{
"error": "sftp_unavailable",
"message": "Remote host does not expose the SFTP subsystem.",
"message": "Remote host does not expose the SFTP subsystem for directory listing.",
"scpFallbackAvailable": false
}
```
@@ -243,7 +245,6 @@ Input:
{
"host": "prod-api.example.com",
"username": "deploy",
"shell": "bash",
"cols": 120,
"rows": 40,
"port": 22,
@@ -267,6 +268,7 @@ Requirements:
* Maintain server-side session state
* Support idle timeout cleanup
* Use the same key-auth inputs and default key discovery as `ssh_exec`
* Use the remote account's default shell; do not write shell setup commands into the PTY after startup
---
@@ -553,6 +555,9 @@ The MVP must include:
* terminal_write
* terminal_read
* terminal_stop
* sftp_list
* sftp_get
* sftp_put
* basic audit logging
---

View File

@@ -1,5 +1,6 @@
using McpSsh.Server;
using McpSsh.Server.Audit;
using McpSsh.Server.Sftp;
using McpSsh.Server.Ssh;
using McpSsh.Server.Terminal;
using McpSsh.Server.Tools;
@@ -14,11 +15,13 @@ builder.Logging.ClearProviders();
builder.Services.AddSingleton<ISystemClock, SystemClock>();
builder.Services.AddSingleton<IFileSystem, LocalFileSystem>();
builder.Services.AddSingleton<ISshKeyResolver, DefaultSshKeyResolver>();
builder.Services.AddSingleton<ISshClientFactory, SshNetClientFactory>();
builder.Services.AddSingleton<IAuditLogger, JsonLineAuditLogger>();
builder.Services.AddSingleton<ISshCommandExecutor, SshNetCommandExecutor>();
builder.Services.AddSingleton<SshExecService>();
builder.Services.AddSingleton<ITerminalConnectionFactory, SshNetTerminalConnectionFactory>();
builder.Services.AddSingleton<TerminalSessionManager>();
builder.Services.AddSingleton<SftpService>();
builder.Services
.AddMcpServer()

View File

@@ -0,0 +1,22 @@
namespace McpSsh.Server.Sftp;
public sealed record SftpEntry(
string Name,
string Path,
string Type,
long Size,
DateTime ModifiedUtc);
public sealed record SftpListResult(
IReadOnlyList<SftpEntry>? Entries,
string? Error = null,
string? Message = null,
bool ScpFallbackAvailable = false);
public sealed record SftpTransferResult(
long BytesTransferred,
string? LocalPath,
string? RemotePath,
string TransferProtocol,
string? Error = null,
string? Message = null);

View File

@@ -0,0 +1,418 @@
using McpSsh.Server.Audit;
using McpSsh.Server.Ssh;
using Renci.SshNet.Common;
namespace McpSsh.Server.Sftp;
public sealed class SftpService
{
public const int DefaultPort = 22;
public const long DefaultMaxTransferBytes = 100L * 1024 * 1024;
public const long MaxTransferBytesLimit = 10L * 1024 * 1024 * 1024;
private readonly ISshClientFactory _clientFactory;
private readonly IAuditLogger _auditLogger;
private readonly ISystemClock _clock;
public SftpService(ISshClientFactory clientFactory, IAuditLogger auditLogger, ISystemClock clock)
{
_clientFactory = clientFactory;
_auditLogger = auditLogger;
_clock = clock;
}
public Task<SftpListResult> ListAsync(
string host,
string username,
string remotePath,
int? port,
string? keyPath,
string? keyPassphrase,
CancellationToken cancellationToken)
{
var request = ValidateRequest(host, username, port, keyPath, keyPassphrase);
if (string.IsNullOrWhiteSpace(remotePath))
{
return Task.FromResult(new SftpListResult(null, "invalid_remote_path", "Remote path is required."));
}
return Task.Run(() =>
{
try
{
using var client = _clientFactory.CreateSftpClient(request);
client.Connect();
cancellationToken.ThrowIfCancellationRequested();
var entries = client
.ListDirectory(remotePath)
.Where(file => file.Name is not "." and not "..")
.Select(file => new SftpEntry(
file.Name,
CombineRemotePath(remotePath, file.Name),
file.IsDirectory ? "directory" : file.IsSymbolicLink ? "symlink" : "file",
file.Length,
file.LastWriteTimeUtc))
.ToArray();
Log("sftp_list", request.Host, success: true, command: remotePath);
return new SftpListResult(entries);
}
catch (Exception ex) when (TryMapSshException(ex, out var error, out var message))
{
Log("sftp_list", request.Host, success: false, errorCode: error, message: message, command: remotePath);
return new SftpListResult(null, error, message, ScpFallbackAvailable: false);
}
}, cancellationToken);
}
public Task<SftpTransferResult> GetAsync(
string host,
string username,
string remotePath,
string localPath,
int? port,
string? keyPath,
string? keyPassphrase,
bool? overwrite,
long? maxBytes,
CancellationToken cancellationToken)
{
var request = ValidateRequest(host, username, port, keyPath, keyPassphrase);
var transfer = ValidateTransferPaths(remotePath, localPath, maxBytes);
var resolvedLocalPath = ResolveLocalPath(transfer.LocalPath);
var allowOverwrite = overwrite ?? false;
if (File.Exists(resolvedLocalPath) && !allowOverwrite)
{
return Task.FromResult(new SftpTransferResult(0, resolvedLocalPath, transfer.RemotePath, "none", "local_file_exists", "Local file already exists."));
}
return Task.Run(() =>
{
Directory.CreateDirectory(Path.GetDirectoryName(resolvedLocalPath) ?? Directory.GetCurrentDirectory());
var tempPath = Path.Combine(Path.GetDirectoryName(resolvedLocalPath) ?? Directory.GetCurrentDirectory(), $".{Path.GetFileName(resolvedLocalPath)}.{Guid.NewGuid():N}.tmp");
try
{
try
{
var bytes = DownloadWithSftp(request, transfer.RemotePath, tempPath, transfer.MaxBytes, cancellationToken);
MoveDownloadedFile(tempPath, resolvedLocalPath, allowOverwrite);
Log("sftp_get", request.Host, success: true, command: transfer.RemotePath);
return new SftpTransferResult(bytes, resolvedLocalPath, transfer.RemotePath, "sftp");
}
catch (Exception ex) when (IsSftpFallbackCandidate(ex))
{
var bytes = DownloadWithScp(request, transfer.RemotePath, tempPath, transfer.MaxBytes, cancellationToken);
MoveDownloadedFile(tempPath, resolvedLocalPath, allowOverwrite);
Log("sftp_get", request.Host, success: true, command: transfer.RemotePath);
return new SftpTransferResult(bytes, resolvedLocalPath, transfer.RemotePath, "scp");
}
}
catch (Exception ex) when (TryMapSshException(ex, out var error, out var message))
{
DeleteQuietly(tempPath);
Log("sftp_get", request.Host, success: false, errorCode: error, message: message, command: transfer.RemotePath);
return new SftpTransferResult(0, resolvedLocalPath, transfer.RemotePath, "none", error, message);
}
}, cancellationToken);
}
public Task<SftpTransferResult> PutAsync(
string host,
string username,
string localPath,
string remotePath,
int? port,
string? keyPath,
string? keyPassphrase,
bool? overwrite,
long? maxBytes,
CancellationToken cancellationToken)
{
return Task.Run(() =>
{
SshConnectionRequest? request = null;
string? resolvedLocalPath = null;
string? resolvedRemotePath = null;
try
{
request = ValidateRequest(host, username, port, keyPath, keyPassphrase);
var transfer = ValidateTransferPaths(remotePath, localPath, maxBytes);
resolvedRemotePath = transfer.RemotePath;
resolvedLocalPath = ResolveExistingLocalFile(transfer.LocalPath, transfer.MaxBytes);
var allowOverwrite = overwrite ?? false;
try
{
var bytes = UploadWithSftp(request, resolvedLocalPath, resolvedRemotePath, allowOverwrite, cancellationToken);
Log("sftp_put", request.Host, success: true, command: resolvedRemotePath);
return new SftpTransferResult(bytes, resolvedLocalPath, resolvedRemotePath, "sftp");
}
catch (Exception ex) when (IsSftpFallbackCandidate(ex))
{
var bytes = UploadWithScp(request, resolvedLocalPath, resolvedRemotePath, allowOverwrite, cancellationToken);
Log("sftp_put", request.Host, success: true, command: resolvedRemotePath);
return new SftpTransferResult(bytes, resolvedLocalPath, resolvedRemotePath, "scp");
}
}
catch (Exception ex) when (TryMapSshException(ex, out var error, out var message))
{
Log("sftp_put", request?.Host ?? host, success: false, errorCode: error, message: message, command: resolvedRemotePath ?? remotePath);
return new SftpTransferResult(0, resolvedLocalPath, resolvedRemotePath ?? remotePath, "none", error, message);
}
}, cancellationToken);
}
private long DownloadWithSftp(SshConnectionRequest request, string remotePath, string tempPath, long maxBytes, CancellationToken cancellationToken)
{
using var client = _clientFactory.CreateSftpClient(request);
client.Connect();
cancellationToken.ThrowIfCancellationRequested();
var attributes = client.GetAttributes(remotePath);
if (attributes.IsDirectory)
{
throw new SshToolException("remote_path_is_directory", "Remote path is a directory.");
}
if (attributes.Size > maxBytes)
{
throw new SshToolException("download_too_large", $"Remote file size {attributes.Size} exceeds the configured maximum of {maxBytes} bytes.");
}
using var output = File.Create(tempPath);
client.DownloadFile(remotePath, output);
return output.Length;
}
private long DownloadWithScp(SshConnectionRequest request, string remotePath, string tempPath, long maxBytes, CancellationToken cancellationToken)
{
using var client = _clientFactory.CreateScpClient(request);
client.Connect();
cancellationToken.ThrowIfCancellationRequested();
using var output = File.Create(tempPath);
client.Download(remotePath, output);
output.Flush();
if (output.Length > maxBytes)
{
throw new SshToolException("download_too_large", $"Downloaded file size {output.Length} exceeds the configured maximum of {maxBytes} bytes.");
}
return output.Length;
}
private long UploadWithSftp(SshConnectionRequest request, string localPath, string remotePath, bool overwrite, CancellationToken cancellationToken)
{
using var client = _clientFactory.CreateSftpClient(request);
client.Connect();
cancellationToken.ThrowIfCancellationRequested();
if (!overwrite && client.Exists(remotePath))
{
throw new SshToolException("remote_file_exists", "Remote file already exists.");
}
using var input = File.OpenRead(localPath);
client.UploadFile(input, remotePath, overwrite);
return input.Length;
}
private long UploadWithScp(SshConnectionRequest request, string localPath, string remotePath, bool overwrite, CancellationToken cancellationToken)
{
if (!overwrite && CheckRemoteExistsWithSsh(request, remotePath, cancellationToken))
{
throw new SshToolException("remote_file_exists", "Remote file already exists.");
}
using var client = _clientFactory.CreateScpClient(request);
client.Connect();
cancellationToken.ThrowIfCancellationRequested();
var file = new FileInfo(localPath);
client.Upload(file, remotePath);
return file.Length;
}
private bool CheckRemoteExistsWithSsh(SshConnectionRequest request, string remotePath, CancellationToken cancellationToken)
{
using var client = _clientFactory.CreateSshClient(request);
client.Connect();
cancellationToken.ThrowIfCancellationRequested();
using var command = client.CreateCommand($"test -e {RemoteShellCommand.Quote(remotePath)}");
command.Execute();
return command.ExitStatus == 0;
}
private static SshConnectionRequest ValidateRequest(string host, string username, int? port, string? keyPath, string? keyPassphrase)
{
if (string.IsNullOrWhiteSpace(host))
{
throw new SshToolException("invalid_host", "Host is required.");
}
if (string.IsNullOrWhiteSpace(username))
{
throw new SshToolException("invalid_username", "Username is required.");
}
var resolvedPort = port ?? DefaultPort;
if (resolvedPort is < 1 or > 65535)
{
throw new SshToolException("invalid_port", "Port must be between 1 and 65535.");
}
return new SshConnectionRequest(
host.Trim(),
username.Trim(),
resolvedPort,
string.IsNullOrWhiteSpace(keyPath) ? null : keyPath.Trim(),
string.IsNullOrEmpty(keyPassphrase) ? null : keyPassphrase);
}
private static (string RemotePath, string LocalPath, long MaxBytes) ValidateTransferPaths(string remotePath, string localPath, long? maxBytes)
{
if (string.IsNullOrWhiteSpace(remotePath))
{
throw new SshToolException("invalid_remote_path", "Remote path is required.");
}
if (string.IsNullOrWhiteSpace(localPath))
{
throw new SshToolException("invalid_local_path", "Local path is required.");
}
var resolvedMaxBytes = maxBytes ?? DefaultMaxTransferBytes;
if (resolvedMaxBytes is < 1 or > MaxTransferBytesLimit)
{
throw new SshToolException("invalid_max_bytes", $"maxBytes must be between 1 and {MaxTransferBytesLimit}.");
}
return (remotePath.Trim(), localPath.Trim(), resolvedMaxBytes);
}
private static string ResolveExistingLocalFile(string localPath, long maxBytes)
{
var resolved = ResolveLocalPath(localPath);
if (!File.Exists(resolved))
{
throw new SshToolException("local_file_not_found", "Local file does not exist.");
}
var file = new FileInfo(resolved);
if (file.Length > maxBytes)
{
throw new SshToolException("upload_too_large", $"Local file size {file.Length} exceeds the configured maximum of {maxBytes} bytes.");
}
return resolved;
}
private static string ResolveLocalPath(string localPath)
{
var expanded = ExpandHomeDirectory(localPath);
var fullPath = Path.GetFullPath(expanded);
var currentDirectory = Path.GetFullPath(Directory.GetCurrentDirectory());
if (!fullPath.StartsWith(currentDirectory + Path.DirectorySeparatorChar, StringComparison.Ordinal) &&
!string.Equals(fullPath, currentDirectory, StringComparison.Ordinal))
{
throw new SshToolException("unsafe_local_path", "Local paths must stay within the current working directory.");
}
return fullPath;
}
private static string ExpandHomeDirectory(string path)
{
if (path == "~")
{
return Environment.GetFolderPath(Environment.SpecialFolder.UserProfile);
}
if (path.StartsWith("~/", StringComparison.Ordinal))
{
return Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.UserProfile), path[2..]);
}
return path;
}
private static void MoveDownloadedFile(string tempPath, string targetPath, bool overwrite)
{
File.Move(tempPath, targetPath, overwrite);
}
private static void DeleteQuietly(string path)
{
try
{
if (File.Exists(path))
{
File.Delete(path);
}
}
catch
{
}
}
private static bool IsSftpFallbackCandidate(Exception ex)
{
return ex is SshException or SftpPathNotFoundException or SftpPermissionDeniedException;
}
private static bool TryMapSshException(Exception ex, out string error, out string message)
{
switch (ex)
{
case SshToolException toolException:
error = toolException.ErrorCode;
message = toolException.Message;
return true;
case SshAuthenticationException:
error = "ssh_authentication_failed";
message = "SSH key authentication failed. Check the username, key path, and key passphrase.";
return true;
case SftpPermissionDeniedException:
error = "permission_denied";
message = ex.Message;
return true;
case SftpPathNotFoundException:
error = "path_not_found";
message = ex.Message;
return true;
case SshConnectionException:
case SshException:
error = "ssh_error";
message = ex.Message;
return true;
case IOException:
case UnauthorizedAccessException:
error = "file_error";
message = ex.Message;
return true;
default:
error = "transfer_error";
message = ex.Message;
return true;
}
}
private static string CombineRemotePath(string directory, string name)
{
if (directory.EndsWith('/'))
{
return directory + name;
}
return directory + "/" + name;
}
private void Log(string tool, string? host, bool success, string? errorCode = null, string? message = null, string? command = null)
{
_auditLogger.Log(new AuditEvent(_clock.UtcNow, tool, host, null, command, success, 0, errorCode, message));
}
}

View File

@@ -0,0 +1,70 @@
using Renci.SshNet;
namespace McpSsh.Server.Ssh;
public interface ISshClientFactory
{
SshClient CreateSshClient(SshConnectionRequest request);
SftpClient CreateSftpClient(SshConnectionRequest request);
ScpClient CreateScpClient(SshConnectionRequest request);
}
public sealed record SshConnectionRequest(
string Host,
string Username,
int Port,
string? KeyPath,
string? KeyPassphrase,
TimeSpan? Timeout = null);
public sealed class SshNetClientFactory : ISshClientFactory
{
private readonly ISshKeyResolver _keyResolver;
public SshNetClientFactory(ISshKeyResolver keyResolver)
{
_keyResolver = keyResolver;
}
public SshClient CreateSshClient(SshConnectionRequest request)
{
var client = new SshClient(CreateConnectionInfo(request));
ApplyTimeout(client, request.Timeout);
return client;
}
public SftpClient CreateSftpClient(SshConnectionRequest request)
{
var client = new SftpClient(CreateConnectionInfo(request));
ApplyTimeout(client, request.Timeout);
return client;
}
public ScpClient CreateScpClient(SshConnectionRequest request)
{
var client = new ScpClient(CreateConnectionInfo(request));
ApplyTimeout(client, request.Timeout);
return client;
}
private ConnectionInfo CreateConnectionInfo(SshConnectionRequest request)
{
var keyPath = _keyResolver.ResolveKeyPath(request.KeyPath);
var keyFile = string.IsNullOrEmpty(request.KeyPassphrase)
? new PrivateKeyFile(keyPath)
: new PrivateKeyFile(keyPath, request.KeyPassphrase);
var auth = new PrivateKeyAuthenticationMethod(request.Username, keyFile);
return new ConnectionInfo(request.Host, request.Port, request.Username, auth);
}
private static void ApplyTimeout(BaseClient client, TimeSpan? timeout)
{
if (timeout is { } value)
{
client.ConnectionInfo.Timeout = value;
}
}
}

View File

@@ -10,11 +10,11 @@ public interface ISshCommandExecutor
public sealed class SshNetCommandExecutor : ISshCommandExecutor
{
private readonly ISshKeyResolver _keyResolver;
private readonly ISshClientFactory _clientFactory;
public SshNetCommandExecutor(ISshKeyResolver keyResolver)
public SshNetCommandExecutor(ISshClientFactory clientFactory)
{
_keyResolver = keyResolver;
_clientFactory = clientFactory;
}
public Task<SshExecResult> ExecuteAsync(SshExecRequest request, CancellationToken cancellationToken)
@@ -26,13 +26,13 @@ public sealed class SshNetCommandExecutor : ISshCommandExecutor
{
try
{
var keyPath = _keyResolver.ResolveKeyPath(request.KeyPath);
using var keyFile = string.IsNullOrEmpty(request.KeyPassphrase)
? new PrivateKeyFile(keyPath)
: new PrivateKeyFile(keyPath, request.KeyPassphrase);
using var client = new SshClient(request.Host, request.Port, request.Username, keyFile);
client.ConnectionInfo.Timeout = TimeSpan.FromSeconds(request.TimeoutSeconds);
using var client = _clientFactory.CreateSshClient(new SshConnectionRequest(
request.Host,
request.Username,
request.Port,
request.KeyPath,
request.KeyPassphrase,
TimeSpan.FromSeconds(request.TimeoutSeconds)));
client.Connect();
cancellationToken.ThrowIfCancellationRequested();

View File

@@ -21,22 +21,23 @@ public interface ITerminalConnectionFactory
public sealed class SshNetTerminalConnectionFactory : ITerminalConnectionFactory
{
private const int ShellBufferSize = 16 * 1024;
private readonly ISshKeyResolver _keyResolver;
private readonly ISshClientFactory _clientFactory;
public SshNetTerminalConnectionFactory(ISshKeyResolver keyResolver)
public SshNetTerminalConnectionFactory(ISshClientFactory clientFactory)
{
_keyResolver = keyResolver;
_clientFactory = clientFactory;
}
public ITerminalConnection Create(TerminalStartRequest request)
{
try
{
var keyPath = _keyResolver.ResolveKeyPath(request.KeyPath);
var keyFile = string.IsNullOrEmpty(request.KeyPassphrase)
? new PrivateKeyFile(keyPath)
: new PrivateKeyFile(keyPath, request.KeyPassphrase);
var client = new SshClient(request.Host, request.Port, request.Username, keyFile);
var client = _clientFactory.CreateSshClient(new SshConnectionRequest(
request.Host,
request.Username,
request.Port,
request.KeyPath,
request.KeyPassphrase));
client.Connect();
@@ -48,7 +49,7 @@ public sealed class SshNetTerminalConnectionFactory : ITerminalConnectionFactory
height: 0,
bufferSize: ShellBufferSize);
return new SshNetTerminalConnection(client, stream, keyFile);
return new SshNetTerminalConnection(client, stream);
}
catch (SshAuthenticationException ex)
{
@@ -64,13 +65,11 @@ public sealed class SshNetTerminalConnectionFactory : ITerminalConnectionFactory
{
private readonly SshClient _client;
private readonly ShellStream _stream;
private readonly PrivateKeyFile _keyFile;
public SshNetTerminalConnection(SshClient client, ShellStream stream, PrivateKeyFile keyFile)
public SshNetTerminalConnection(SshClient client, ShellStream stream)
{
_client = client;
_stream = stream;
_keyFile = keyFile;
}
public bool DataAvailable => _stream.DataAvailable;
@@ -89,7 +88,6 @@ public sealed class SshNetTerminalConnectionFactory : ITerminalConnectionFactory
{
_stream.Dispose();
_client.Dispose();
_keyFile.Dispose();
}
}
}

View File

@@ -3,7 +3,6 @@ namespace McpSsh.Server.Terminal;
public sealed record TerminalStartRequest(
string Host,
string Username,
string? Shell,
int Cols,
int Rows,
int Port,

View File

@@ -34,7 +34,6 @@ public sealed class TerminalSessionManager : IDisposable
public Task<TerminalStartResult> StartAsync(
string host,
string username,
string? shell,
int? cols,
int? rows,
int? port,
@@ -43,7 +42,7 @@ public sealed class TerminalSessionManager : IDisposable
int? idleTimeoutSeconds,
CancellationToken cancellationToken)
{
var request = ValidateStart(host, username, shell, cols, rows, port, keyPath, keyPassphrase, idleTimeoutSeconds);
var request = ValidateStart(host, username, cols, rows, port, keyPath, keyPassphrase, idleTimeoutSeconds);
var started = _clock.UtcNow;
try
@@ -59,19 +58,14 @@ public sealed class TerminalSessionManager : IDisposable
throw new SshToolException("terminal_session_conflict", "Unable to allocate a unique terminal session ID.");
}
if (!string.IsNullOrWhiteSpace(request.Shell))
{
connection.Write($"exec {request.Shell}\n");
}
session.ReaderTask = Task.Run(() => ReadLoopAsync(session), CancellationToken.None);
Log("terminal_start", request.Host, success: true, command: request.Shell);
Log("terminal_start", request.Host, success: true);
return Task.FromResult(new TerminalStartResult(sessionId));
}
catch (SshToolException ex)
{
Log("terminal_start", request.Host, success: false, errorCode: ex.ErrorCode, message: ex.Message, command: request.Shell);
Log("terminal_start", request.Host, success: false, errorCode: ex.ErrorCode, message: ex.Message);
return Task.FromResult(new TerminalStartResult(null, ex.ErrorCode, ex.Message));
}
}
@@ -183,7 +177,7 @@ public sealed class TerminalSessionManager : IDisposable
}
}
private static TerminalStartRequest ValidateStart(string host, string username, string? shell, int? cols, int? rows, int? port, string? keyPath, string? keyPassphrase, int? idleTimeoutSeconds)
private static TerminalStartRequest ValidateStart(string host, string username, int? cols, int? rows, int? port, string? keyPath, string? keyPassphrase, int? idleTimeoutSeconds)
{
if (string.IsNullOrWhiteSpace(host))
{
@@ -222,7 +216,6 @@ public sealed class TerminalSessionManager : IDisposable
return new TerminalStartRequest(
host.Trim(),
username.Trim(),
string.IsNullOrWhiteSpace(shell) ? null : shell.Trim(),
resolvedCols,
resolvedRows,
resolvedPort,

View File

@@ -1,4 +1,5 @@
using System.ComponentModel;
using McpSsh.Server.Sftp;
using McpSsh.Server.Ssh;
using McpSsh.Server.Terminal;
using ModelContextProtocol.Server;
@@ -10,11 +11,13 @@ public sealed class SshTools
{
private readonly SshExecService _sshExecService;
private readonly TerminalSessionManager _terminalSessionManager;
private readonly SftpService _sftpService;
public SshTools(SshExecService sshExecService, TerminalSessionManager terminalSessionManager)
public SshTools(SshExecService sshExecService, TerminalSessionManager terminalSessionManager, SftpService sftpService)
{
_sshExecService = sshExecService;
_terminalSessionManager = terminalSessionManager;
_sftpService = sftpService;
}
[McpServerTool(Name = "ssh_exec", Destructive = true)]
@@ -38,7 +41,6 @@ public sealed class SshTools
public Task<TerminalStartResult> StartTerminalAsync(
[Description("Remote hostname or IP address. OpenSSH aliases are not supported in this vertical slice.")] string host,
[Description("Remote SSH username.")] string username,
[Description("Optional shell to exec after PTY allocation, for example bash or sh.")] string? shell = null,
[Description("Terminal columns. Defaults to 120.")] int? cols = null,
[Description("Terminal rows. Defaults to 40.")] int? rows = null,
[Description("Remote SSH port. Defaults to 22.")] int? port = null,
@@ -47,7 +49,7 @@ public sealed class SshTools
[Description("Idle timeout in seconds. Defaults to 900.")] int? idleTimeoutSeconds = null,
CancellationToken cancellationToken = default)
{
return _terminalSessionManager.StartAsync(host, username, shell, cols, rows, port, keyPath, keyPassphrase, idleTimeoutSeconds, cancellationToken);
return _terminalSessionManager.StartAsync(host, username, cols, rows, port, keyPath, keyPassphrase, idleTimeoutSeconds, cancellationToken);
}
[McpServerTool(Name = "terminal_write", Destructive = true)]
@@ -75,4 +77,52 @@ public sealed class SshTools
{
return _terminalSessionManager.Stop(sessionId);
}
[McpServerTool(Name = "sftp_list", Destructive = false)]
[Description("List remote directory contents over SFTP.")]
public Task<SftpListResult> ListSftpAsync(
[Description("Remote hostname or IP address. OpenSSH aliases are not supported in this vertical slice.")] string host,
[Description("Remote SSH username.")] string username,
[Description("Remote directory path to list.")] string remotePath,
[Description("Remote SSH port. Defaults to 22.")] int? port = null,
[Description("Optional local private key path. Defaults to ~/.ssh/id_ed25519, ~/.ssh/id_ecdsa, then ~/.ssh/id_rsa.")] string? keyPath = null,
[Description("Optional private key passphrase. Use only with trusted MCP clients.")] string? keyPassphrase = null,
CancellationToken cancellationToken = default)
{
return _sftpService.ListAsync(host, username, remotePath, port, keyPath, keyPassphrase, cancellationToken);
}
[McpServerTool(Name = "sftp_get", Destructive = false)]
[Description("Download a remote file using SFTP, silently falling back to SCP when SFTP is unavailable.")]
public Task<SftpTransferResult> GetSftpAsync(
[Description("Remote hostname or IP address. OpenSSH aliases are not supported in this vertical slice.")] string host,
[Description("Remote SSH username.")] string username,
[Description("Remote file path to download.")] string remotePath,
[Description("Local destination path. Must stay within the server working directory.")] string localPath,
[Description("Remote SSH port. Defaults to 22.")] int? port = null,
[Description("Optional local private key path. Defaults to ~/.ssh/id_ed25519, ~/.ssh/id_ecdsa, then ~/.ssh/id_rsa.")] string? keyPath = null,
[Description("Optional private key passphrase. Use only with trusted MCP clients.")] string? keyPassphrase = null,
[Description("Overwrite an existing local file. Defaults to false.")] bool? overwrite = null,
[Description("Maximum download size in bytes. Defaults to 104857600.")] long? maxBytes = null,
CancellationToken cancellationToken = default)
{
return _sftpService.GetAsync(host, username, remotePath, localPath, port, keyPath, keyPassphrase, overwrite, maxBytes, cancellationToken);
}
[McpServerTool(Name = "sftp_put", Destructive = true)]
[Description("Upload a local file using SFTP, silently falling back to SCP when SFTP is unavailable.")]
public Task<SftpTransferResult> PutSftpAsync(
[Description("Remote hostname or IP address. OpenSSH aliases are not supported in this vertical slice.")] string host,
[Description("Remote SSH username.")] string username,
[Description("Local file path to upload. Must stay within the server working directory.")] string localPath,
[Description("Remote destination file path.")] string remotePath,
[Description("Remote SSH port. Defaults to 22.")] int? port = null,
[Description("Optional local private key path. Defaults to ~/.ssh/id_ed25519, ~/.ssh/id_ecdsa, then ~/.ssh/id_rsa.")] string? keyPath = null,
[Description("Optional private key passphrase. Use only with trusted MCP clients.")] string? keyPassphrase = null,
[Description("Overwrite an existing remote file. Defaults to false.")] bool? overwrite = null,
[Description("Maximum upload size in bytes. Defaults to 104857600.")] long? maxBytes = null,
CancellationToken cancellationToken = default)
{
return _sftpService.PutAsync(host, username, localPath, remotePath, port, keyPath, keyPassphrase, overwrite, maxBytes, cancellationToken);
}
}

View File

@@ -0,0 +1,75 @@
using McpSsh.Server;
using McpSsh.Server.Audit;
using McpSsh.Server.Sftp;
using McpSsh.Server.Ssh;
using Renci.SshNet;
namespace McpSsh.Tests;
public sealed class SftpServiceTests
{
[Fact]
public async Task PutAsync_ReturnsErrorWhenLocalFileIsOutsideWorkingDirectory()
{
var service = CreateService();
var result = await service.PutAsync("host", "user", "/tmp/file.txt", "/tmp/file.txt", null, null, null, null, null, CancellationToken.None);
Assert.Equal("unsafe_local_path", result.Error);
}
[Fact]
public async Task PutAsync_ReturnsErrorWhenLocalFileIsMissing()
{
var service = CreateService();
var result = await service.PutAsync("host", "user", "missing-file.txt", "/tmp/file.txt", null, null, null, null, null, CancellationToken.None);
Assert.Equal("local_file_not_found", result.Error);
}
[Fact]
public async Task GetAsync_ReturnsErrorWhenLocalFileExistsAndOverwriteIsFalse()
{
var path = Path.Combine("sftp-test-existing.txt");
await File.WriteAllTextAsync(path, "existing");
try
{
var service = CreateService();
var result = await service.GetAsync("host", "user", "/tmp/file.txt", path, null, null, null, overwrite: false, maxBytes: null, CancellationToken.None);
Assert.Equal("local_file_exists", result.Error);
}
finally
{
File.Delete(path);
}
}
private static SftpService CreateService()
{
return new SftpService(new ThrowingClientFactory(), new CapturingAuditLogger(), new FixedClock());
}
private sealed class ThrowingClientFactory : ISshClientFactory
{
public SshClient CreateSshClient(SshConnectionRequest request) => throw new NotSupportedException();
public SftpClient CreateSftpClient(SshConnectionRequest request) => throw new NotSupportedException();
public ScpClient CreateScpClient(SshConnectionRequest request) => throw new NotSupportedException();
}
private sealed class CapturingAuditLogger : IAuditLogger
{
public void Log(AuditEvent auditEvent)
{
}
}
private sealed class FixedClock : ISystemClock
{
public DateTimeOffset UtcNow => DateTimeOffset.Parse("2026-05-24T12:00:00Z");
}
}

View File

@@ -8,11 +8,11 @@ namespace McpSsh.Tests;
public sealed class TerminalSessionManagerTests
{
[Fact]
public async Task StartAsync_CreatesSessionAndExecsRequestedShell()
public async Task StartAsync_CreatesSessionWithoutWritingShellSetup()
{
using var manager = CreateManager(out var factory, out _);
var result = await manager.StartAsync(" prod-api ", " deploy ", "bash", null, null, null, "/keys/id", "secret", null, CancellationToken.None);
var result = await manager.StartAsync(" prod-api ", " deploy ", null, null, null, "/keys/id", "secret", null, CancellationToken.None);
Assert.Null(result.Error);
Assert.StartsWith("term_", result.SessionId);
@@ -21,14 +21,14 @@ public sealed class TerminalSessionManagerTests
Assert.Equal("deploy", factory.Request.Username);
Assert.Equal("/keys/id", factory.Request.KeyPath);
Assert.Equal("secret", factory.Request.KeyPassphrase);
Assert.Contains("exec bash\n", factory.Connection.Writes);
Assert.Empty(factory.Connection.Writes);
}
[Fact]
public async Task Write_SendsInputToActiveSession()
{
using var manager = CreateManager(out var factory, out _);
var start = await manager.StartAsync("prod-api", "deploy", null, null, null, null, null, null, null, CancellationToken.None);
var start = await manager.StartAsync("prod-api", "deploy", null, null, null, null, null, null, CancellationToken.None);
var result = manager.Write(start.SessionId!, "uptime\n");
@@ -40,7 +40,7 @@ public sealed class TerminalSessionManagerTests
public async Task Read_DrainsBufferedOutputAndReportsTruncation()
{
using var manager = CreateManager(out var factory, out _);
var start = await manager.StartAsync("prod-api", "deploy", null, null, null, null, null, null, null, CancellationToken.None);
var start = await manager.StartAsync("prod-api", "deploy", null, null, null, null, null, null, CancellationToken.None);
factory.Connection.QueueOutput("abcdef");
var first = await ReadUntilOutputAsync(manager, start.SessionId!, 3);
@@ -67,7 +67,7 @@ public sealed class TerminalSessionManagerTests
public async Task Stop_DisposesAndRemovesSession()
{
using var manager = CreateManager(out var factory, out _);
var start = await manager.StartAsync("prod-api", "deploy", null, null, null, null, null, null, null, CancellationToken.None);
var start = await manager.StartAsync("prod-api", "deploy", null, null, null, null, null, null, CancellationToken.None);
var result = manager.Stop(start.SessionId!);
var writeAfterStop = manager.Write(start.SessionId!, "pwd\n");