Build initial MCP SSH server

This commit is contained in:
Vibe Myass
2026-05-24 20:45:12 +00:00
commit a8f7e8f483
28 changed files with 2116 additions and 0 deletions

View File

@@ -0,0 +1,12 @@
namespace McpSsh.Server.Audit;
public sealed record AuditEvent(
DateTimeOffset TimestampUtc,
string Tool,
string? Host,
string? Username,
string? Command,
bool Success,
long DurationMs,
string? ErrorCode = null,
string? Message = null);

View File

@@ -0,0 +1,6 @@
namespace McpSsh.Server.Audit;
public interface IAuditLogger
{
void Log(AuditEvent auditEvent);
}

View File

@@ -0,0 +1,61 @@
using System.Text.Json;
namespace McpSsh.Server.Audit;
public sealed class JsonLineAuditLogger : IAuditLogger
{
private static readonly JsonSerializerOptions SerializerOptions = new(JsonSerializerDefaults.Web);
private readonly TextWriter _writer;
public JsonLineAuditLogger()
: this(Console.Error)
{
}
public JsonLineAuditLogger(TextWriter writer)
{
_writer = writer;
}
public void Log(AuditEvent auditEvent)
{
ArgumentNullException.ThrowIfNull(auditEvent);
var safeEvent = auditEvent with
{
Command = Redact(auditEvent.Command),
Message = Redact(auditEvent.Message)
};
_writer.WriteLine(JsonSerializer.Serialize(safeEvent, SerializerOptions));
}
private static string? Redact(string? value)
{
if (string.IsNullOrWhiteSpace(value))
{
return value;
}
var redacted = value;
foreach (var marker in new[] { "password=", "passphrase=", "token=", "secret=" })
{
var index = redacted.IndexOf(marker, StringComparison.OrdinalIgnoreCase);
if (index < 0)
{
continue;
}
var valueStart = index + marker.Length;
var valueEnd = redacted.IndexOfAny([' ', '\t', '\r', '\n', '"', '\''], valueStart);
if (valueEnd < 0)
{
valueEnd = redacted.Length;
}
redacted = string.Concat(redacted.AsSpan(0, valueStart), "***", redacted.AsSpan(valueEnd));
}
return redacted;
}
}

View File

@@ -0,0 +1,18 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>net10.0</TargetFramework>
<RuntimeIdentifiers>win-x64;linux-x64;linux-arm64;osx-x64;osx-arm64</RuntimeIdentifiers>
<AssemblyName>mcp-ssh</AssemblyName>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.Extensions.Hosting" Version="10.0.0" />
<PackageReference Include="ModelContextProtocol" Version="1.3.0" />
<PackageReference Include="SSH.NET" Version="2025.1.0" />
</ItemGroup>
</Project>

View File

@@ -0,0 +1,31 @@
using McpSsh.Server;
using McpSsh.Server.Audit;
using McpSsh.Server.Ssh;
using McpSsh.Server.Terminal;
using McpSsh.Server.Tools;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
var builder = Host.CreateApplicationBuilder(args);
builder.Logging.AddConsole(options =>
{
options.LogToStandardErrorThreshold = LogLevel.Trace;
});
builder.Services.AddSingleton<ISystemClock, SystemClock>();
builder.Services.AddSingleton<IFileSystem, LocalFileSystem>();
builder.Services.AddSingleton<ISshKeyResolver, DefaultSshKeyResolver>();
builder.Services.AddSingleton<IAuditLogger, JsonLineAuditLogger>();
builder.Services.AddSingleton<ISshCommandExecutor, SshNetCommandExecutor>();
builder.Services.AddSingleton<SshExecService>();
builder.Services.AddSingleton<ITerminalConnectionFactory, SshNetTerminalConnectionFactory>();
builder.Services.AddSingleton<TerminalSessionManager>();
builder.Services
.AddMcpServer()
.WithStdioServerTransport()
.WithTools<SshTools>();
await builder.Build().RunAsync();

View File

@@ -0,0 +1,91 @@
namespace McpSsh.Server.Ssh;
public interface ISshKeyResolver
{
string ResolveKeyPath(string? requestedKeyPath);
}
public interface IFileSystem
{
bool FileExists(string path);
}
public sealed class LocalFileSystem : IFileSystem
{
public bool FileExists(string path) => File.Exists(path);
}
public sealed class DefaultSshKeyResolver : ISshKeyResolver
{
private static readonly string[] DefaultKeyNames = ["id_ed25519", "id_ecdsa", "id_rsa"];
private readonly IFileSystem _fileSystem;
private readonly string _sshDirectory;
public DefaultSshKeyResolver(IFileSystem fileSystem)
: this(fileSystem, Path.Combine(GetHomeDirectory(), ".ssh"))
{
}
public DefaultSshKeyResolver(IFileSystem fileSystem, string sshDirectory)
{
_fileSystem = fileSystem;
_sshDirectory = sshDirectory;
}
public string ResolveKeyPath(string? requestedKeyPath)
{
if (!string.IsNullOrWhiteSpace(requestedKeyPath))
{
var expanded = ExpandHomeDirectory(requestedKeyPath.Trim());
if (_fileSystem.FileExists(expanded))
{
return expanded;
}
throw new SshToolException("ssh_key_not_found", $"SSH private key not found at '{expanded}'.");
}
return ResolveDefaultKeyPath();
}
private string ResolveDefaultKeyPath()
{
foreach (var keyName in DefaultKeyNames)
{
var path = Path.Combine(_sshDirectory, keyName);
if (_fileSystem.FileExists(path))
{
return path;
}
}
throw new SshToolException("ssh_key_not_found", $"No default SSH private key found in '{_sshDirectory}'.");
}
private static string ExpandHomeDirectory(string path)
{
if (path == "~")
{
return GetHomeDirectory();
}
if (path.StartsWith("~/", StringComparison.Ordinal))
{
return Path.Combine(GetHomeDirectory(), path[2..]);
}
return path;
}
private static string GetHomeDirectory()
{
var home = Environment.GetFolderPath(Environment.SpecialFolder.UserProfile);
if (!string.IsNullOrWhiteSpace(home))
{
return home;
}
return Environment.GetEnvironmentVariable("HOME")
?? throw new SshToolException("home_directory_not_found", "Unable to determine the current user's home directory.");
}
}

View File

@@ -0,0 +1,36 @@
using System.Text;
namespace McpSsh.Server.Ssh;
public static class RemoteShellCommand
{
public static string Build(string command, string? cwd)
{
if (string.IsNullOrWhiteSpace(cwd))
{
return command;
}
return $"cd {Quote(cwd)} && {command}";
}
public static string Quote(string value)
{
var builder = new StringBuilder(value.Length + 2);
builder.Append('\'');
foreach (var character in value)
{
if (character == '\'')
{
builder.Append("'\\''");
}
else
{
builder.Append(character);
}
}
builder.Append('\'');
return builder.ToString();
}
}

View File

@@ -0,0 +1,73 @@
using Renci.SshNet;
using Renci.SshNet.Common;
namespace McpSsh.Server.Ssh;
public interface ISshCommandExecutor
{
Task<SshExecResult> ExecuteAsync(SshExecRequest request, CancellationToken cancellationToken);
}
public sealed class SshNetCommandExecutor : ISshCommandExecutor
{
private readonly ISshKeyResolver _keyResolver;
public SshNetCommandExecutor(ISshKeyResolver keyResolver)
{
_keyResolver = keyResolver;
}
public Task<SshExecResult> ExecuteAsync(SshExecRequest request, CancellationToken cancellationToken)
{
return Task.Run(() => Execute(request, cancellationToken), cancellationToken);
}
private SshExecResult Execute(SshExecRequest request, CancellationToken cancellationToken)
{
try
{
var keyPath = _keyResolver.ResolveKeyPath(request.KeyPath);
using var keyFile = string.IsNullOrEmpty(request.KeyPassphrase)
? new PrivateKeyFile(keyPath)
: new PrivateKeyFile(keyPath, request.KeyPassphrase);
using var client = new SshClient(request.Host, request.Port, request.Username, keyFile);
client.ConnectionInfo.Timeout = TimeSpan.FromSeconds(request.TimeoutSeconds);
client.Connect();
cancellationToken.ThrowIfCancellationRequested();
var remoteCommand = RemoteShellCommand.Build(request.Command, request.Cwd);
using var command = client.CreateCommand(remoteCommand);
command.CommandTimeout = TimeSpan.FromSeconds(request.TimeoutSeconds);
var started = DateTimeOffset.UtcNow;
var stdout = command.Execute();
var duration = DateTimeOffset.UtcNow - started;
return new SshExecResult(
command.ExitStatus ?? -1,
stdout,
command.Error,
(long)duration.TotalMilliseconds,
TimedOut: false,
Error: null,
Message: null);
}
catch (SshAuthenticationException ex)
{
throw new SshToolException("ssh_authentication_failed", "SSH key authentication failed. Check the username, key path, and key passphrase.", ex);
}
catch (SshOperationTimeoutException ex)
{
throw new SshToolException("ssh_timeout", "SSH command timed out.", ex);
}
catch (TimeoutException ex)
{
throw new SshToolException("ssh_timeout", "SSH command timed out.", ex);
}
catch (SshException ex)
{
throw new SshToolException("ssh_error", ex.Message, ex);
}
}
}

View File

@@ -0,0 +1,11 @@
namespace McpSsh.Server.Ssh;
public sealed record SshExecRequest(
string Host,
string Username,
string Command,
int Port,
string? Cwd,
int TimeoutSeconds,
string? KeyPath,
string? KeyPassphrase);

View File

@@ -0,0 +1,10 @@
namespace McpSsh.Server.Ssh;
public sealed record SshExecResult(
int ExitCode,
string Stdout,
string Stderr,
long DurationMs,
bool TimedOut,
string? Error,
string? Message);

View File

@@ -0,0 +1,134 @@
using System.Diagnostics;
using McpSsh.Server.Audit;
namespace McpSsh.Server.Ssh;
public sealed class SshExecService
{
public const int DefaultPort = 22;
public const int DefaultTimeoutSeconds = 30;
public const int MaxTimeoutSeconds = 300;
private readonly ISshCommandExecutor _executor;
private readonly IAuditLogger _auditLogger;
private readonly ISystemClock _clock;
public SshExecService(
ISshCommandExecutor executor,
IAuditLogger auditLogger,
ISystemClock clock)
{
_executor = executor;
_auditLogger = auditLogger;
_clock = clock;
}
public async Task<SshExecResult> ExecuteAsync(
string host,
string username,
string command,
int? port,
string? cwd,
string? keyPath,
string? keyPassphrase,
int? timeoutSeconds,
CancellationToken cancellationToken)
{
var request = Validate(host, username, command, port, cwd, keyPath, keyPassphrase, timeoutSeconds);
var stopwatch = Stopwatch.StartNew();
try
{
using var timeout = new CancellationTokenSource(TimeSpan.FromSeconds(request.TimeoutSeconds));
using var linked = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeout.Token);
var result = await _executor.ExecuteAsync(request, linked.Token).ConfigureAwait(false);
stopwatch.Stop();
Log(request, success: true, stopwatch.ElapsedMilliseconds);
return result with { DurationMs = result.DurationMs > 0 ? result.DurationMs : stopwatch.ElapsedMilliseconds };
}
catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested)
{
stopwatch.Stop();
Log(request, success: false, stopwatch.ElapsedMilliseconds, "ssh_timeout", "SSH command timed out.");
return new SshExecResult(
ExitCode: -1,
Stdout: string.Empty,
Stderr: string.Empty,
DurationMs: stopwatch.ElapsedMilliseconds,
TimedOut: true,
Error: "ssh_timeout",
Message: "SSH command timed out.");
}
catch (SshToolException ex)
{
stopwatch.Stop();
Log(request, success: false, stopwatch.ElapsedMilliseconds, ex.ErrorCode, ex.Message);
return new SshExecResult(
ExitCode: -1,
Stdout: string.Empty,
Stderr: string.Empty,
DurationMs: stopwatch.ElapsedMilliseconds,
TimedOut: ex.ErrorCode == "ssh_timeout",
Error: ex.ErrorCode,
Message: ex.Message);
}
}
private static SshExecRequest Validate(string host, string username, string command, int? port, string? cwd, string? keyPath, string? keyPassphrase, int? timeoutSeconds)
{
if (string.IsNullOrWhiteSpace(host))
{
throw new SshToolException("invalid_host", "Host is required.");
}
if (string.IsNullOrWhiteSpace(username))
{
throw new SshToolException("invalid_username", "Username is required.");
}
if (string.IsNullOrWhiteSpace(command))
{
throw new SshToolException("invalid_command", "Command is required.");
}
var resolvedPort = port ?? DefaultPort;
if (resolvedPort is < 1 or > 65535)
{
throw new SshToolException("invalid_port", "Port must be between 1 and 65535.");
}
var resolvedTimeout = timeoutSeconds ?? DefaultTimeoutSeconds;
if (resolvedTimeout is < 1 or > MaxTimeoutSeconds)
{
throw new SshToolException("invalid_timeout", $"Timeout must be between 1 and {MaxTimeoutSeconds} seconds.");
}
return new SshExecRequest(
host.Trim(),
username.Trim(),
command,
resolvedPort,
string.IsNullOrWhiteSpace(cwd) ? null : cwd,
resolvedTimeout,
string.IsNullOrWhiteSpace(keyPath) ? null : keyPath.Trim(),
string.IsNullOrEmpty(keyPassphrase) ? null : keyPassphrase);
}
private void Log(SshExecRequest request, bool success, long durationMs, string? errorCode = null, string? message = null)
{
_auditLogger.Log(new AuditEvent(
_clock.UtcNow,
"ssh_exec",
request.Host,
request.Username,
request.Command,
success,
durationMs,
errorCode,
message));
}
}

View File

@@ -0,0 +1,18 @@
namespace McpSsh.Server.Ssh;
public sealed class SshToolException : Exception
{
public SshToolException(string errorCode, string message)
: base(message)
{
ErrorCode = errorCode;
}
public SshToolException(string errorCode, string message, Exception innerException)
: base(message, innerException)
{
ErrorCode = errorCode;
}
public string ErrorCode { get; }
}

View File

@@ -0,0 +1,11 @@
namespace McpSsh.Server;
public interface ISystemClock
{
DateTimeOffset UtcNow { get; }
}
public sealed class SystemClock : ISystemClock
{
public DateTimeOffset UtcNow => DateTimeOffset.UtcNow;
}

View File

@@ -0,0 +1,95 @@
using Renci.SshNet;
using Renci.SshNet.Common;
using McpSsh.Server.Ssh;
namespace McpSsh.Server.Terminal;
public interface ITerminalConnection : IDisposable
{
bool DataAvailable { get; }
string ReadAvailable();
void Write(string input);
}
public interface ITerminalConnectionFactory
{
ITerminalConnection Create(TerminalStartRequest request);
}
public sealed class SshNetTerminalConnectionFactory : ITerminalConnectionFactory
{
private const int ShellBufferSize = 16 * 1024;
private readonly ISshKeyResolver _keyResolver;
public SshNetTerminalConnectionFactory(ISshKeyResolver keyResolver)
{
_keyResolver = keyResolver;
}
public ITerminalConnection Create(TerminalStartRequest request)
{
try
{
var keyPath = _keyResolver.ResolveKeyPath(request.KeyPath);
var keyFile = string.IsNullOrEmpty(request.KeyPassphrase)
? new PrivateKeyFile(keyPath)
: new PrivateKeyFile(keyPath, request.KeyPassphrase);
var client = new SshClient(request.Host, request.Port, request.Username, keyFile);
client.Connect();
var stream = client.CreateShellStream(
"xterm-256color",
(uint)request.Cols,
(uint)request.Rows,
width: 0,
height: 0,
bufferSize: ShellBufferSize);
return new SshNetTerminalConnection(client, stream, keyFile);
}
catch (SshAuthenticationException ex)
{
throw new SshToolException("ssh_authentication_failed", "SSH key authentication failed. Check the username, key path, and key passphrase.", ex);
}
catch (SshException ex)
{
throw new SshToolException("ssh_error", ex.Message, ex);
}
}
private sealed class SshNetTerminalConnection : ITerminalConnection
{
private readonly SshClient _client;
private readonly ShellStream _stream;
private readonly PrivateKeyFile _keyFile;
public SshNetTerminalConnection(SshClient client, ShellStream stream, PrivateKeyFile keyFile)
{
_client = client;
_stream = stream;
_keyFile = keyFile;
}
public bool DataAvailable => _stream.DataAvailable;
public string ReadAvailable()
{
return _stream.Read();
}
public void Write(string input)
{
_stream.Write(input);
}
public void Dispose()
{
_stream.Dispose();
_client.Dispose();
_keyFile.Dispose();
}
}
}

View File

@@ -0,0 +1,33 @@
namespace McpSsh.Server.Terminal;
public sealed record TerminalStartRequest(
string Host,
string Username,
string? Shell,
int Cols,
int Rows,
int Port,
int IdleTimeoutSeconds,
string? KeyPath,
string? KeyPassphrase);
public sealed record TerminalStartResult(
string? SessionId,
string? Error = null,
string? Message = null);
public sealed record TerminalWriteResult(
bool Accepted,
string? Error = null,
string? Message = null);
public sealed record TerminalReadResult(
string Output,
bool Truncated,
string? Error = null,
string? Message = null);
public sealed record TerminalStopResult(
bool Stopped,
string? Error = null,
string? Message = null);

View File

@@ -0,0 +1,347 @@
using System.Collections.Concurrent;
using System.Security.Cryptography;
using System.Text;
using McpSsh.Server.Audit;
using McpSsh.Server.Ssh;
namespace McpSsh.Server.Terminal;
public sealed class TerminalSessionManager : IDisposable
{
public const int DefaultPort = 22;
public const int DefaultCols = 120;
public const int DefaultRows = 40;
public const int DefaultIdleTimeoutSeconds = 900;
public const int MaxIdleTimeoutSeconds = 86_400;
public const int DefaultReadMaxBytes = 12_000;
public const int MaxReadBytes = 1_000_000;
private const int MaxBufferCharacters = 1_000_000;
private readonly ConcurrentDictionary<string, TerminalSession> _sessions = new(StringComparer.Ordinal);
private readonly ITerminalConnectionFactory _connectionFactory;
private readonly IAuditLogger _auditLogger;
private readonly ISystemClock _clock;
private readonly Timer _cleanupTimer;
public TerminalSessionManager(ITerminalConnectionFactory connectionFactory, IAuditLogger auditLogger, ISystemClock clock)
{
_connectionFactory = connectionFactory;
_auditLogger = auditLogger;
_clock = clock;
_cleanupTimer = new Timer(_ => CleanupIdleSessions(), null, TimeSpan.FromMinutes(1), TimeSpan.FromMinutes(1));
}
public Task<TerminalStartResult> StartAsync(
string host,
string username,
string? shell,
int? cols,
int? rows,
int? port,
string? keyPath,
string? keyPassphrase,
int? idleTimeoutSeconds,
CancellationToken cancellationToken)
{
var request = ValidateStart(host, username, shell, cols, rows, port, keyPath, keyPassphrase, idleTimeoutSeconds);
var started = _clock.UtcNow;
try
{
cancellationToken.ThrowIfCancellationRequested();
var connection = _connectionFactory.Create(request);
var sessionId = CreateSessionId();
var session = new TerminalSession(sessionId, request.Host, started, request.IdleTimeoutSeconds, connection);
if (!_sessions.TryAdd(sessionId, session))
{
connection.Dispose();
throw new SshToolException("terminal_session_conflict", "Unable to allocate a unique terminal session ID.");
}
if (!string.IsNullOrWhiteSpace(request.Shell))
{
connection.Write($"exec {request.Shell}\n");
}
session.ReaderTask = Task.Run(() => ReadLoopAsync(session), CancellationToken.None);
Log("terminal_start", request.Host, success: true, command: request.Shell);
return Task.FromResult(new TerminalStartResult(sessionId));
}
catch (SshToolException ex)
{
Log("terminal_start", request.Host, success: false, errorCode: ex.ErrorCode, message: ex.Message, command: request.Shell);
return Task.FromResult(new TerminalStartResult(null, ex.ErrorCode, ex.Message));
}
}
public TerminalWriteResult Write(string sessionId, string input)
{
if (string.IsNullOrWhiteSpace(sessionId))
{
return new TerminalWriteResult(false, "invalid_session_id", "Session ID is required.");
}
if (input is null)
{
return new TerminalWriteResult(false, "invalid_input", "Input is required.");
}
if (!_sessions.TryGetValue(sessionId, out var session))
{
return new TerminalWriteResult(false, "terminal_session_not_found", $"Terminal session '{sessionId}' was not found.");
}
try
{
session.Touch(_clock.UtcNow);
session.Connection.Write(input);
Log("terminal_write", session.Host, success: true, command: input);
return new TerminalWriteResult(true);
}
catch (ObjectDisposedException ex)
{
RemoveSession(sessionId);
Log("terminal_write", session.Host, success: false, errorCode: "terminal_session_closed", message: ex.Message, command: input);
return new TerminalWriteResult(false, "terminal_session_closed", "Terminal session is closed.");
}
}
public TerminalReadResult Read(string sessionId, int? maxBytes)
{
if (string.IsNullOrWhiteSpace(sessionId))
{
return new TerminalReadResult(string.Empty, false, "invalid_session_id", "Session ID is required.");
}
if (!_sessions.TryGetValue(sessionId, out var session))
{
return new TerminalReadResult(string.Empty, false, "terminal_session_not_found", $"Terminal session '{sessionId}' was not found.");
}
var byteLimit = ValidateMaxBytes(maxBytes);
session.Touch(_clock.UtcNow);
var (output, truncated) = session.Drain(byteLimit);
Log("terminal_read", session.Host, success: true);
return new TerminalReadResult(output, truncated);
}
public TerminalStopResult Stop(string sessionId)
{
if (string.IsNullOrWhiteSpace(sessionId))
{
return new TerminalStopResult(false, "invalid_session_id", "Session ID is required.");
}
if (!_sessions.TryRemove(sessionId, out var session))
{
return new TerminalStopResult(false, "terminal_session_not_found", $"Terminal session '{sessionId}' was not found.");
}
session.Dispose();
Log("terminal_stop", session.Host, success: true);
return new TerminalStopResult(true);
}
public void Dispose()
{
_cleanupTimer.Dispose();
foreach (var sessionId in _sessions.Keys)
{
RemoveSession(sessionId);
}
}
private async Task ReadLoopAsync(TerminalSession session)
{
while (!session.Cancellation.IsCancellationRequested)
{
try
{
if (session.Connection.DataAvailable)
{
var output = session.Connection.ReadAvailable();
if (!string.IsNullOrEmpty(output))
{
session.Append(output);
}
}
else
{
await Task.Delay(50, session.Cancellation).ConfigureAwait(false);
}
}
catch (OperationCanceledException)
{
return;
}
catch (ObjectDisposedException)
{
return;
}
}
}
private static TerminalStartRequest ValidateStart(string host, string username, string? shell, int? cols, int? rows, int? port, string? keyPath, string? keyPassphrase, int? idleTimeoutSeconds)
{
if (string.IsNullOrWhiteSpace(host))
{
throw new SshToolException("invalid_host", "Host is required.");
}
if (string.IsNullOrWhiteSpace(username))
{
throw new SshToolException("invalid_username", "Username is required.");
}
var resolvedPort = port ?? DefaultPort;
if (resolvedPort is < 1 or > 65535)
{
throw new SshToolException("invalid_port", "Port must be between 1 and 65535.");
}
var resolvedCols = cols ?? DefaultCols;
if (resolvedCols is < 20 or > 500)
{
throw new SshToolException("invalid_cols", "Terminal columns must be between 20 and 500.");
}
var resolvedRows = rows ?? DefaultRows;
if (resolvedRows is < 5 or > 500)
{
throw new SshToolException("invalid_rows", "Terminal rows must be between 5 and 500.");
}
var resolvedIdleTimeout = idleTimeoutSeconds ?? DefaultIdleTimeoutSeconds;
if (resolvedIdleTimeout is < 1 or > MaxIdleTimeoutSeconds)
{
throw new SshToolException("invalid_idle_timeout", $"Idle timeout must be between 1 and {MaxIdleTimeoutSeconds} seconds.");
}
return new TerminalStartRequest(
host.Trim(),
username.Trim(),
string.IsNullOrWhiteSpace(shell) ? null : shell.Trim(),
resolvedCols,
resolvedRows,
resolvedPort,
resolvedIdleTimeout,
string.IsNullOrWhiteSpace(keyPath) ? null : keyPath.Trim(),
string.IsNullOrEmpty(keyPassphrase) ? null : keyPassphrase);
}
private static int ValidateMaxBytes(int? maxBytes)
{
var resolved = maxBytes ?? DefaultReadMaxBytes;
if (resolved is < 1 or > MaxReadBytes)
{
throw new SshToolException("invalid_max_bytes", $"maxBytes must be between 1 and {MaxReadBytes}.");
}
return resolved;
}
private void CleanupIdleSessions()
{
var now = _clock.UtcNow;
foreach (var (sessionId, session) in _sessions)
{
if (now - session.LastActivityUtc >= TimeSpan.FromSeconds(session.IdleTimeoutSeconds))
{
RemoveSession(sessionId);
Log("terminal_idle_cleanup", session.Host, success: true);
}
}
}
private void RemoveSession(string sessionId)
{
if (_sessions.TryRemove(sessionId, out var session))
{
session.Dispose();
}
}
private void Log(string tool, string? host, bool success, string? errorCode = null, string? message = null, string? command = null)
{
_auditLogger.Log(new AuditEvent(_clock.UtcNow, tool, host, null, command, success, 0, errorCode, message));
}
private static string CreateSessionId()
{
Span<byte> bytes = stackalloc byte[12];
RandomNumberGenerator.Fill(bytes);
return $"term_{Convert.ToHexString(bytes).ToLowerInvariant()}";
}
private sealed class TerminalSession : IDisposable
{
private readonly object _gate = new();
private readonly StringBuilder _buffer = new();
private readonly CancellationTokenSource _cancellation = new();
public TerminalSession(string sessionId, string host, DateTimeOffset now, int idleTimeoutSeconds, ITerminalConnection connection)
{
SessionId = sessionId;
Host = host;
CreatedUtc = now;
LastActivityUtc = now;
IdleTimeoutSeconds = idleTimeoutSeconds;
Connection = connection;
Cancellation = _cancellation.Token;
}
public string SessionId { get; }
public string Host { get; }
public DateTimeOffset CreatedUtc { get; }
public DateTimeOffset LastActivityUtc { get; private set; }
public int IdleTimeoutSeconds { get; }
public ITerminalConnection Connection { get; }
public CancellationToken Cancellation { get; }
public Task? ReaderTask { get; set; }
public void Touch(DateTimeOffset now)
{
lock (_gate)
{
LastActivityUtc = now;
}
}
public void Append(string output)
{
lock (_gate)
{
_buffer.Append(output);
if (_buffer.Length > MaxBufferCharacters)
{
_buffer.Remove(0, _buffer.Length - MaxBufferCharacters);
}
}
}
public (string Output, bool Truncated) Drain(int maxCharacters)
{
lock (_gate)
{
if (_buffer.Length == 0)
{
return (string.Empty, false);
}
var count = Math.Min(maxCharacters, _buffer.Length);
var output = _buffer.ToString(0, count);
_buffer.Remove(0, count);
return (output, _buffer.Length > 0);
}
}
public void Dispose()
{
_cancellation.Cancel();
Connection.Dispose();
_cancellation.Dispose();
}
}
}

View File

@@ -0,0 +1,78 @@
using System.ComponentModel;
using McpSsh.Server.Ssh;
using McpSsh.Server.Terminal;
using ModelContextProtocol.Server;
namespace McpSsh.Server.Tools;
[McpServerToolType]
public sealed class SshTools
{
private readonly SshExecService _sshExecService;
private readonly TerminalSessionManager _terminalSessionManager;
public SshTools(SshExecService sshExecService, TerminalSessionManager terminalSessionManager)
{
_sshExecService = sshExecService;
_terminalSessionManager = terminalSessionManager;
}
[McpServerTool(Name = "ssh_exec", Destructive = true)]
[Description("Execute a single command over SSH using key-based authentication.")]
public Task<SshExecResult> ExecuteAsync(
[Description("Remote hostname or IP address. OpenSSH aliases are not supported in this vertical slice.")] string host,
[Description("Remote SSH username.")] string username,
[Description("Command to execute on the remote host.")] string command,
[Description("Remote SSH port. Defaults to 22.")] int? port = null,
[Description("Optional remote working directory.")] string? cwd = null,
[Description("Optional local private key path. Defaults to ~/.ssh/id_ed25519, ~/.ssh/id_ecdsa, then ~/.ssh/id_rsa.")] string? keyPath = null,
[Description("Optional private key passphrase. Use only with trusted MCP clients.")] string? keyPassphrase = null,
[Description("Timeout in seconds. Defaults to 30 and is capped at 300.")] int? timeoutSeconds = null,
CancellationToken cancellationToken = default)
{
return _sshExecService.ExecuteAsync(host, username, command, port, cwd, keyPath, keyPassphrase, timeoutSeconds, cancellationToken);
}
[McpServerTool(Name = "terminal_start", Destructive = true)]
[Description("Start a persistent SSH PTY shell session using key-based authentication.")]
public Task<TerminalStartResult> StartTerminalAsync(
[Description("Remote hostname or IP address. OpenSSH aliases are not supported in this vertical slice.")] string host,
[Description("Remote SSH username.")] string username,
[Description("Optional shell to exec after PTY allocation, for example bash or sh.")] string? shell = null,
[Description("Terminal columns. Defaults to 120.")] int? cols = null,
[Description("Terminal rows. Defaults to 40.")] int? rows = null,
[Description("Remote SSH port. Defaults to 22.")] int? port = null,
[Description("Optional local private key path. Defaults to ~/.ssh/id_ed25519, ~/.ssh/id_ecdsa, then ~/.ssh/id_rsa.")] string? keyPath = null,
[Description("Optional private key passphrase. Use only with trusted MCP clients.")] string? keyPassphrase = null,
[Description("Idle timeout in seconds. Defaults to 900.")] int? idleTimeoutSeconds = null,
CancellationToken cancellationToken = default)
{
return _terminalSessionManager.StartAsync(host, username, shell, cols, rows, port, keyPath, keyPassphrase, idleTimeoutSeconds, cancellationToken);
}
[McpServerTool(Name = "terminal_write", Destructive = true)]
[Description("Write input to an active SSH terminal session.")]
public TerminalWriteResult WriteTerminal(
[Description("Terminal session ID returned by terminal_start.")] string sessionId,
[Description("Input to write to the terminal. Include newline characters when submitting commands.")] string input)
{
return _terminalSessionManager.Write(sessionId, input);
}
[McpServerTool(Name = "terminal_read", Destructive = false)]
[Description("Read buffered output from an active SSH terminal session.")]
public TerminalReadResult ReadTerminal(
[Description("Terminal session ID returned by terminal_start.")] string sessionId,
[Description("Maximum output characters to read. Defaults to 12000.")] int? maxBytes = null)
{
return _terminalSessionManager.Read(sessionId, maxBytes);
}
[McpServerTool(Name = "terminal_stop", Destructive = true)]
[Description("Stop and remove an SSH terminal session.")]
public TerminalStopResult StopTerminal(
[Description("Terminal session ID returned by terminal_start.")] string sessionId)
{
return _terminalSessionManager.Stop(sessionId);
}
}