commit a8f7e8f483580839f63734dba877c74a09efdf8e Author: Vibe Myass Date: Sun May 24 20:45:12 2026 +0000 Build initial MCP SSH server diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d5e3c41 --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +bin/ +obj/ +TestResults/ +artifacts/ +*.user +*.suo +.vs/ +.vscode/ diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..f37f7d6 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,573 @@ +# AGENTS.md + +## Project + +MCP SSH Server + +Self-contained MCP server binary that provides SSH command execution, persistent terminal sessions, and file transfer support with minimal user configuration. + +The implementation must rely on the user's existing OpenSSH configuration and SSH environment. + +--- + +# Core Product Requirements + +## Long-Term Goals + +* SSH command execution +* Persistent interactive terminal sessions +* Remote file transfer support +* Minimal user-side MCP configuration +* Self-contained binary distribution +* Use existing `~/.ssh/config` +* Use existing SSH keys and SSH agent +* Support SSH aliases exactly as users use them in terminal + +## Current Vertical Slice + +The first implementation pass intentionally narrows the MVP: + +* .NET 10 stdio MCP server +* `ssh_exec` +* `terminal_start`, `terminal_write`, `terminal_read`, and `terminal_stop` +* SSH.NET-based SSH connections +* Tool-supplied `host`, `username`, optional `port`, optional `keyPath`, and optional `keyPassphrase` +* Default key discovery from `~/.ssh/id_ed25519`, `~/.ssh/id_ecdsa`, then `~/.ssh/id_rsa` when `keyPath` is omitted +* No SSH agent, OpenSSH alias, `ssh -G`, ProxyJump, ProxyCommand, or SFTP support yet +* Basic audit logging and timeout enforcement + +## Non-Goals + +* Do not replace OpenSSH +* Do not store SSH private keys +* Do not expose arbitrary local shell execution +* Do not require duplicate SSH configuration inside MCP settings +* Do not build a web UI for MVP + +--- + +# Distribution Requirements + +The server must ship as a single self-contained executable. + +Target platforms: + +* Windows x64 +* Linux x64 +* Linux arm64 +* macOS x64 +* macOS arm64 + +Expected binary names: + +```text +mcp-ssh +mcp-ssh.exe +``` + +Example .NET publish command: + +```bash +dotnet publish \ + -c Release \ + -r linux-x64 \ + --self-contained true \ + /p:PublishSingleFile=true \ + /p:PublishTrimmed=true +``` + +--- + +# MCP Transport + +Use stdio transport for MVP. + +Example MCP config: + +```json +{ + "mcpServers": { + "ssh": { + "command": "/path/to/mcp-ssh", + "args": [] + } + } +} +``` + +No SSH-specific configuration should be required in MCP config. + +--- + +# SSH Configuration Resolution + +The long-term implementation should use OpenSSH configuration resolution. + +Primary mechanism: + +```bash +ssh -G +``` + +The implementation should parse the resolved output. + +Required resolved fields: + +* hostname +* user +* port +* identityfile +* proxyjump +* proxycommand +* stricthostkeychecking +* userknownhostsfile +* identitiesonly + +The implementation must support: + +* SSH aliases +* bastion hosts +* proxy jump +* existing SSH agent +* known hosts validation + +Fallback parsing of `~/.ssh/config` may be implemented if `ssh -G` is unavailable. + +This is not part of the current vertical slice. The current implementation treats `host` as a hostname or IP address supplied directly to the MCP tool. + +--- + +# Authentication Requirements + +Authentication should reuse the user's existing SSH environment. + +Supported: + +* SSH agent +* identity files +* OpenSSH config +* known hosts +* optional password prompt support + +The implementation must never persist private keys. + +--- + +# File Transfer Strategy + +SFTP is the primary transfer mechanism. + +Rationale: + +* More stable across SSH implementations +* Better metadata support +* Better directory traversal semantics +* Avoids SCP shell parsing issues +* Modern OpenSSH increasingly routes SCP over SFTP internally + +Priority order: + +```text +1. SFTP +2. SCP fallback (optional future enhancement) +``` + +The implementation must: + +* Validate SFTP subsystem availability during connection +* Return structured capability errors if unavailable +* Not silently downgrade to SCP + +Example error: + +```json +{ + "error": "sftp_unavailable", + "message": "Remote host does not expose the SFTP subsystem.", + "scpFallbackAvailable": false +} +``` + +--- + +# MCP Tools + +## ssh_exec + +Execute a single SSH command. + +Input: + +```json +{ + "host": "prod-api.example.com", + "username": "deploy", + "command": "systemctl status nginx", + "cwd": "/var/www", + "port": 22, + "keyPath": "~/.ssh/id_ed25519", + "keyPassphrase": "optional-passphrase", + "timeoutSeconds": 30 +} +``` + +Output: + +```json +{ + "exitCode": 0, + "stdout": "...", + "stderr": "...", + "durationMs": 421 +} +``` + +Requirements: + +* Enforce timeout +* Capture stdout separately +* Capture stderr separately +* Preserve non-zero exit codes +* Authenticate with explicit `keyPath`, or the first available default private key from `~/.ssh/id_ed25519`, `~/.ssh/id_ecdsa`, then `~/.ssh/id_rsa` +* Support optional `keyPassphrase` for encrypted private keys + +--- + +## terminal_start + +Start a persistent PTY shell session. + +Input: + +```json +{ + "host": "prod-api.example.com", + "username": "deploy", + "shell": "bash", + "cols": 120, + "rows": 40, + "port": 22, + "keyPath": "~/.ssh/id_ed25519", + "keyPassphrase": "optional-passphrase", + "idleTimeoutSeconds": 900 +} +``` + +Output: + +```json +{ + "sessionId": "term_abc123" +} +``` + +Requirements: + +* Allocate PTY +* Maintain server-side session state +* Support idle timeout cleanup +* Use the same key-auth inputs and default key discovery as `ssh_exec` + +--- + +## terminal_write + +Write to an active terminal session. + +Input: + +```json +{ + "sessionId": "term_abc123", + "input": "tail -f /var/log/nginx/error.log\\n" +} +``` + +Output: + +```json +{ + "accepted": true +} +``` + +--- + +## terminal_read + +Read buffered output from a terminal session. + +Input: + +```json +{ + "sessionId": "term_abc123", + "maxBytes": 12000 +} +``` + +Output: + +```json +{ + "output": "...", + "truncated": false +} +``` + +--- + +## terminal_stop + +Stop and remove a terminal session. + +Input: + +```json +{ + "sessionId": "term_abc123" +} +``` + +Output: + +```json +{ + "stopped": true +} +``` + +--- + +## sftp_list + +List remote directory contents. + +Input: + +```json +{ + "host": "prod-api", + "remotePath": "/var/www" +} +``` + +Output: + +```json +{ + "entries": [ + { + "name": "app", + "path": "/var/www/app", + "type": "directory", + "size": 4096, + "modifiedUtc": "2026-05-24T12:00:00Z" + } + ] +} +``` + +--- + +## sftp_get + +Download a remote file. + +Input: + +```json +{ + "host": "prod-api", + "remotePath": "/var/log/app.log", + "localPath": "./downloads/app.log" +} +``` + +Output: + +```json +{ + "bytesTransferred": 123456, + "localPath": "./downloads/app.log" +} +``` + +Requirements: + +* Enforce max download size +* Prevent unsafe local path traversal +* Fail on overwrite unless explicitly enabled + +--- + +## sftp_put + +Upload a local file. + +Input: + +```json +{ + "host": "prod-api", + "localPath": "./dist/app.tar.gz", + "remotePath": "/tmp/app.tar.gz", + "overwrite": false +} +``` + +Output: + +```json +{ + "bytesTransferred": 123456, + "remotePath": "/tmp/app.tar.gz" +} +``` + +Requirements: + +* Enforce max upload size +* Fail if remote file exists and overwrite is false + +--- + +# Session Management + +Terminal sessions should be maintained in memory. + +Example session state: + +```json +{ + "sessionId": "term_abc123", + "host": "prod-api", + "createdUtc": "...", + "lastActivityUtc": "...", + "idleTimeoutSeconds": 900 +} +``` + +Requirements: + +* Cryptographically random session IDs +* Idle timeout cleanup +* Graceful cleanup on shutdown +* Output buffering with maximum limits + +--- + +# Security Requirements + +The implementation must default to safe behavior. + +Required safeguards: + +* Command timeout enforcement +* Upload size limits +* Download size limits +* Audit logging +* No private key persistence +* No arbitrary local command execution +* No command-content or host blocking; access control is delegated to SSH users, SSH keys, and remote-side permissions + +--- + +# Audit Logging + +Every tool call should be logged. + +Example: + +```json +{ + "timestampUtc": "2026-05-24T12:00:00Z", + "tool": "ssh_exec", + "host": "prod-api", + "command": "systemctl status nginx", + "success": true, + "durationMs": 421 +} +``` + +Sensitive values must be redacted. + +--- + +# Recommended Implementation Stack + +Language: + +* C# +* .NET 10+ + +Recommended libraries: + +* Official MCP SDK +* SSH.NET +* OpenSSH `ssh -G` + +--- + +# Recommended Internal Interfaces + +```csharp +public interface ISshConfigResolver +{ + /// + /// Resolves an SSH host alias into effective connection settings. + /// + /// The SSH host alias. + /// The resolved SSH configuration. + ResolvedSshConfig Resolve(string hostAlias); +} + +public interface ISshSessionFactory +{ + /// + /// Creates a connected SSH client. + /// + /// The SSH host alias. + /// A connected SSH client. + SshClient CreateSshClient(string hostAlias); + + /// + /// Creates a connected SFTP client. + /// + /// The SSH host alias. + /// A connected SFTP client. + SftpClient CreateSftpClient(string hostAlias); +} +``` + +--- + +# MVP Scope + +The MVP must include: + +* Self-contained binary +* stdio transport +* ssh_exec +* terminal_start +* terminal_write +* terminal_read +* terminal_stop +* basic audit logging + +--- + +# Future Enhancements + +Potential future work: + +* Streamable HTTP transport +* SCP compatibility fallback +* Per-host policy configuration +* File checksum validation +* Directory upload/download +* Remote command templates +* Resource-based remote file browsing +* Multi-hop SSH validation +* Secret redaction improvements +* Per-tool and per-host authorization policy diff --git a/McpSsh.slnx b/McpSsh.slnx new file mode 100644 index 0000000..11f46c6 --- /dev/null +++ b/McpSsh.slnx @@ -0,0 +1,8 @@ + + + + + + + + diff --git a/README.md b/README.md new file mode 100644 index 0000000..7a62a0f --- /dev/null +++ b/README.md @@ -0,0 +1,34 @@ +# MCP SSH Server + +Self-contained MCP server for SSH command execution and persistent terminal sessions. + +## Build + +```bash +dotnet build McpSsh.slnx +dotnet test McpSsh.slnx --no-build +``` + +## Publish + +Publish all supported runtime IDs as single-file, self-contained binaries: + +```bash +./scripts/publish.sh +``` + +Publish one runtime: + +```bash +./scripts/publish.sh linux-x64 +``` + +The script writes binaries under `artifacts/publish//`. + +Trimming is disabled by default because the MCP SDK discovers tools through reflection. To experiment with trimming after validating tool discovery in the published binary: + +```bash +PUBLISH_TRIMMED=true ./scripts/publish.sh linux-x64 +``` + +NativeAOT is not enabled. This code should be treated as not AOT-ready until the MCP SDK reflection path and SSH.NET dependencies are explicitly validated under `PublishAot=true`. diff --git a/scripts/publish.sh b/scripts/publish.sh new file mode 100755 index 0000000..19f7012 --- /dev/null +++ b/scripts/publish.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +PROJECT="$ROOT_DIR/src/McpSsh.Server/McpSsh.Server.csproj" +CONFIGURATION="${CONFIGURATION:-Release}" +PUBLISH_TRIMMED="${PUBLISH_TRIMMED:-false}" + +if [[ $# -gt 0 ]]; then + RIDS=("$@") +else + RIDS=(win-x64 linux-x64 linux-arm64 osx-x64 osx-arm64) +fi + +for rid in "${RIDS[@]}"; do + output="$ROOT_DIR/artifacts/publish/$rid" + dotnet publish "$PROJECT" \ + -c "$CONFIGURATION" \ + -r "$rid" \ + --self-contained true \ + -o "$output" \ + -p:PublishSingleFile=true \ + -p:PublishTrimmed="$PUBLISH_TRIMMED" \ + -p:EnableCompressionInSingleFile=true +done diff --git a/src/McpSsh.Server/Audit/AuditEvent.cs b/src/McpSsh.Server/Audit/AuditEvent.cs new file mode 100644 index 0000000..3504ffe --- /dev/null +++ b/src/McpSsh.Server/Audit/AuditEvent.cs @@ -0,0 +1,12 @@ +namespace McpSsh.Server.Audit; + +public sealed record AuditEvent( + DateTimeOffset TimestampUtc, + string Tool, + string? Host, + string? Username, + string? Command, + bool Success, + long DurationMs, + string? ErrorCode = null, + string? Message = null); diff --git a/src/McpSsh.Server/Audit/IAuditLogger.cs b/src/McpSsh.Server/Audit/IAuditLogger.cs new file mode 100644 index 0000000..35ccf95 --- /dev/null +++ b/src/McpSsh.Server/Audit/IAuditLogger.cs @@ -0,0 +1,6 @@ +namespace McpSsh.Server.Audit; + +public interface IAuditLogger +{ + void Log(AuditEvent auditEvent); +} diff --git a/src/McpSsh.Server/Audit/JsonLineAuditLogger.cs b/src/McpSsh.Server/Audit/JsonLineAuditLogger.cs new file mode 100644 index 0000000..3eda446 --- /dev/null +++ b/src/McpSsh.Server/Audit/JsonLineAuditLogger.cs @@ -0,0 +1,61 @@ +using System.Text.Json; + +namespace McpSsh.Server.Audit; + +public sealed class JsonLineAuditLogger : IAuditLogger +{ + private static readonly JsonSerializerOptions SerializerOptions = new(JsonSerializerDefaults.Web); + private readonly TextWriter _writer; + + public JsonLineAuditLogger() + : this(Console.Error) + { + } + + public JsonLineAuditLogger(TextWriter writer) + { + _writer = writer; + } + + public void Log(AuditEvent auditEvent) + { + ArgumentNullException.ThrowIfNull(auditEvent); + + var safeEvent = auditEvent with + { + Command = Redact(auditEvent.Command), + Message = Redact(auditEvent.Message) + }; + + _writer.WriteLine(JsonSerializer.Serialize(safeEvent, SerializerOptions)); + } + + private static string? Redact(string? value) + { + if (string.IsNullOrWhiteSpace(value)) + { + return value; + } + + var redacted = value; + foreach (var marker in new[] { "password=", "passphrase=", "token=", "secret=" }) + { + var index = redacted.IndexOf(marker, StringComparison.OrdinalIgnoreCase); + if (index < 0) + { + continue; + } + + var valueStart = index + marker.Length; + var valueEnd = redacted.IndexOfAny([' ', '\t', '\r', '\n', '"', '\''], valueStart); + if (valueEnd < 0) + { + valueEnd = redacted.Length; + } + + redacted = string.Concat(redacted.AsSpan(0, valueStart), "***", redacted.AsSpan(valueEnd)); + } + + return redacted; + } +} diff --git a/src/McpSsh.Server/McpSsh.Server.csproj b/src/McpSsh.Server/McpSsh.Server.csproj new file mode 100644 index 0000000..58305dc --- /dev/null +++ b/src/McpSsh.Server/McpSsh.Server.csproj @@ -0,0 +1,18 @@ + + + + Exe + net10.0 + win-x64;linux-x64;linux-arm64;osx-x64;osx-arm64 + mcp-ssh + enable + enable + + + + + + + + + diff --git a/src/McpSsh.Server/Program.cs b/src/McpSsh.Server/Program.cs new file mode 100644 index 0000000..7991abb --- /dev/null +++ b/src/McpSsh.Server/Program.cs @@ -0,0 +1,31 @@ +using McpSsh.Server; +using McpSsh.Server.Audit; +using McpSsh.Server.Ssh; +using McpSsh.Server.Terminal; +using McpSsh.Server.Tools; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; + +var builder = Host.CreateApplicationBuilder(args); + +builder.Logging.AddConsole(options => +{ + options.LogToStandardErrorThreshold = LogLevel.Trace; +}); + +builder.Services.AddSingleton(); +builder.Services.AddSingleton(); +builder.Services.AddSingleton(); +builder.Services.AddSingleton(); +builder.Services.AddSingleton(); +builder.Services.AddSingleton(); +builder.Services.AddSingleton(); +builder.Services.AddSingleton(); + +builder.Services + .AddMcpServer() + .WithStdioServerTransport() + .WithTools(); + +await builder.Build().RunAsync(); diff --git a/src/McpSsh.Server/Ssh/DefaultSshKeyResolver.cs b/src/McpSsh.Server/Ssh/DefaultSshKeyResolver.cs new file mode 100644 index 0000000..0750c5f --- /dev/null +++ b/src/McpSsh.Server/Ssh/DefaultSshKeyResolver.cs @@ -0,0 +1,91 @@ +namespace McpSsh.Server.Ssh; + +public interface ISshKeyResolver +{ + string ResolveKeyPath(string? requestedKeyPath); +} + +public interface IFileSystem +{ + bool FileExists(string path); +} + +public sealed class LocalFileSystem : IFileSystem +{ + public bool FileExists(string path) => File.Exists(path); +} + +public sealed class DefaultSshKeyResolver : ISshKeyResolver +{ + private static readonly string[] DefaultKeyNames = ["id_ed25519", "id_ecdsa", "id_rsa"]; + private readonly IFileSystem _fileSystem; + private readonly string _sshDirectory; + + public DefaultSshKeyResolver(IFileSystem fileSystem) + : this(fileSystem, Path.Combine(GetHomeDirectory(), ".ssh")) + { + } + + public DefaultSshKeyResolver(IFileSystem fileSystem, string sshDirectory) + { + _fileSystem = fileSystem; + _sshDirectory = sshDirectory; + } + + public string ResolveKeyPath(string? requestedKeyPath) + { + if (!string.IsNullOrWhiteSpace(requestedKeyPath)) + { + var expanded = ExpandHomeDirectory(requestedKeyPath.Trim()); + if (_fileSystem.FileExists(expanded)) + { + return expanded; + } + + throw new SshToolException("ssh_key_not_found", $"SSH private key not found at '{expanded}'."); + } + + return ResolveDefaultKeyPath(); + } + + private string ResolveDefaultKeyPath() + { + foreach (var keyName in DefaultKeyNames) + { + var path = Path.Combine(_sshDirectory, keyName); + if (_fileSystem.FileExists(path)) + { + return path; + } + } + + throw new SshToolException("ssh_key_not_found", $"No default SSH private key found in '{_sshDirectory}'."); + } + + private static string ExpandHomeDirectory(string path) + { + if (path == "~") + { + return GetHomeDirectory(); + } + + if (path.StartsWith("~/", StringComparison.Ordinal)) + { + return Path.Combine(GetHomeDirectory(), path[2..]); + } + + return path; + } + + private static string GetHomeDirectory() + { + var home = Environment.GetFolderPath(Environment.SpecialFolder.UserProfile); + if (!string.IsNullOrWhiteSpace(home)) + { + return home; + } + + return Environment.GetEnvironmentVariable("HOME") + ?? throw new SshToolException("home_directory_not_found", "Unable to determine the current user's home directory."); + } +} diff --git a/src/McpSsh.Server/Ssh/RemoteShellCommand.cs b/src/McpSsh.Server/Ssh/RemoteShellCommand.cs new file mode 100644 index 0000000..9639a7e --- /dev/null +++ b/src/McpSsh.Server/Ssh/RemoteShellCommand.cs @@ -0,0 +1,36 @@ +using System.Text; + +namespace McpSsh.Server.Ssh; + +public static class RemoteShellCommand +{ + public static string Build(string command, string? cwd) + { + if (string.IsNullOrWhiteSpace(cwd)) + { + return command; + } + + return $"cd {Quote(cwd)} && {command}"; + } + + public static string Quote(string value) + { + var builder = new StringBuilder(value.Length + 2); + builder.Append('\''); + foreach (var character in value) + { + if (character == '\'') + { + builder.Append("'\\''"); + } + else + { + builder.Append(character); + } + } + + builder.Append('\''); + return builder.ToString(); + } +} diff --git a/src/McpSsh.Server/Ssh/SshCommandExecutor.cs b/src/McpSsh.Server/Ssh/SshCommandExecutor.cs new file mode 100644 index 0000000..0505dbc --- /dev/null +++ b/src/McpSsh.Server/Ssh/SshCommandExecutor.cs @@ -0,0 +1,73 @@ +using Renci.SshNet; +using Renci.SshNet.Common; + +namespace McpSsh.Server.Ssh; + +public interface ISshCommandExecutor +{ + Task ExecuteAsync(SshExecRequest request, CancellationToken cancellationToken); +} + +public sealed class SshNetCommandExecutor : ISshCommandExecutor +{ + private readonly ISshKeyResolver _keyResolver; + + public SshNetCommandExecutor(ISshKeyResolver keyResolver) + { + _keyResolver = keyResolver; + } + + public Task ExecuteAsync(SshExecRequest request, CancellationToken cancellationToken) + { + return Task.Run(() => Execute(request, cancellationToken), cancellationToken); + } + + private SshExecResult Execute(SshExecRequest request, CancellationToken cancellationToken) + { + try + { + var keyPath = _keyResolver.ResolveKeyPath(request.KeyPath); + using var keyFile = string.IsNullOrEmpty(request.KeyPassphrase) + ? new PrivateKeyFile(keyPath) + : new PrivateKeyFile(keyPath, request.KeyPassphrase); + using var client = new SshClient(request.Host, request.Port, request.Username, keyFile); + + client.ConnectionInfo.Timeout = TimeSpan.FromSeconds(request.TimeoutSeconds); + client.Connect(); + + cancellationToken.ThrowIfCancellationRequested(); + + var remoteCommand = RemoteShellCommand.Build(request.Command, request.Cwd); + using var command = client.CreateCommand(remoteCommand); + command.CommandTimeout = TimeSpan.FromSeconds(request.TimeoutSeconds); + var started = DateTimeOffset.UtcNow; + var stdout = command.Execute(); + var duration = DateTimeOffset.UtcNow - started; + + return new SshExecResult( + command.ExitStatus ?? -1, + stdout, + command.Error, + (long)duration.TotalMilliseconds, + TimedOut: false, + Error: null, + Message: null); + } + catch (SshAuthenticationException ex) + { + throw new SshToolException("ssh_authentication_failed", "SSH key authentication failed. Check the username, key path, and key passphrase.", ex); + } + catch (SshOperationTimeoutException ex) + { + throw new SshToolException("ssh_timeout", "SSH command timed out.", ex); + } + catch (TimeoutException ex) + { + throw new SshToolException("ssh_timeout", "SSH command timed out.", ex); + } + catch (SshException ex) + { + throw new SshToolException("ssh_error", ex.Message, ex); + } + } +} diff --git a/src/McpSsh.Server/Ssh/SshExecRequest.cs b/src/McpSsh.Server/Ssh/SshExecRequest.cs new file mode 100644 index 0000000..8bcbf89 --- /dev/null +++ b/src/McpSsh.Server/Ssh/SshExecRequest.cs @@ -0,0 +1,11 @@ +namespace McpSsh.Server.Ssh; + +public sealed record SshExecRequest( + string Host, + string Username, + string Command, + int Port, + string? Cwd, + int TimeoutSeconds, + string? KeyPath, + string? KeyPassphrase); diff --git a/src/McpSsh.Server/Ssh/SshExecResult.cs b/src/McpSsh.Server/Ssh/SshExecResult.cs new file mode 100644 index 0000000..70f36f4 --- /dev/null +++ b/src/McpSsh.Server/Ssh/SshExecResult.cs @@ -0,0 +1,10 @@ +namespace McpSsh.Server.Ssh; + +public sealed record SshExecResult( + int ExitCode, + string Stdout, + string Stderr, + long DurationMs, + bool TimedOut, + string? Error, + string? Message); diff --git a/src/McpSsh.Server/Ssh/SshExecService.cs b/src/McpSsh.Server/Ssh/SshExecService.cs new file mode 100644 index 0000000..c7d7cdb --- /dev/null +++ b/src/McpSsh.Server/Ssh/SshExecService.cs @@ -0,0 +1,134 @@ +using System.Diagnostics; +using McpSsh.Server.Audit; + +namespace McpSsh.Server.Ssh; + +public sealed class SshExecService +{ + public const int DefaultPort = 22; + public const int DefaultTimeoutSeconds = 30; + public const int MaxTimeoutSeconds = 300; + + private readonly ISshCommandExecutor _executor; + private readonly IAuditLogger _auditLogger; + private readonly ISystemClock _clock; + + public SshExecService( + ISshCommandExecutor executor, + IAuditLogger auditLogger, + ISystemClock clock) + { + _executor = executor; + _auditLogger = auditLogger; + _clock = clock; + } + + public async Task ExecuteAsync( + string host, + string username, + string command, + int? port, + string? cwd, + string? keyPath, + string? keyPassphrase, + int? timeoutSeconds, + CancellationToken cancellationToken) + { + var request = Validate(host, username, command, port, cwd, keyPath, keyPassphrase, timeoutSeconds); + var stopwatch = Stopwatch.StartNew(); + + try + { + using var timeout = new CancellationTokenSource(TimeSpan.FromSeconds(request.TimeoutSeconds)); + using var linked = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeout.Token); + + var result = await _executor.ExecuteAsync(request, linked.Token).ConfigureAwait(false); + stopwatch.Stop(); + Log(request, success: true, stopwatch.ElapsedMilliseconds); + + return result with { DurationMs = result.DurationMs > 0 ? result.DurationMs : stopwatch.ElapsedMilliseconds }; + } + catch (OperationCanceledException) when (!cancellationToken.IsCancellationRequested) + { + stopwatch.Stop(); + Log(request, success: false, stopwatch.ElapsedMilliseconds, "ssh_timeout", "SSH command timed out."); + + return new SshExecResult( + ExitCode: -1, + Stdout: string.Empty, + Stderr: string.Empty, + DurationMs: stopwatch.ElapsedMilliseconds, + TimedOut: true, + Error: "ssh_timeout", + Message: "SSH command timed out."); + } + catch (SshToolException ex) + { + stopwatch.Stop(); + Log(request, success: false, stopwatch.ElapsedMilliseconds, ex.ErrorCode, ex.Message); + + return new SshExecResult( + ExitCode: -1, + Stdout: string.Empty, + Stderr: string.Empty, + DurationMs: stopwatch.ElapsedMilliseconds, + TimedOut: ex.ErrorCode == "ssh_timeout", + Error: ex.ErrorCode, + Message: ex.Message); + } + } + + private static SshExecRequest Validate(string host, string username, string command, int? port, string? cwd, string? keyPath, string? keyPassphrase, int? timeoutSeconds) + { + if (string.IsNullOrWhiteSpace(host)) + { + throw new SshToolException("invalid_host", "Host is required."); + } + + if (string.IsNullOrWhiteSpace(username)) + { + throw new SshToolException("invalid_username", "Username is required."); + } + + if (string.IsNullOrWhiteSpace(command)) + { + throw new SshToolException("invalid_command", "Command is required."); + } + + var resolvedPort = port ?? DefaultPort; + if (resolvedPort is < 1 or > 65535) + { + throw new SshToolException("invalid_port", "Port must be between 1 and 65535."); + } + + var resolvedTimeout = timeoutSeconds ?? DefaultTimeoutSeconds; + if (resolvedTimeout is < 1 or > MaxTimeoutSeconds) + { + throw new SshToolException("invalid_timeout", $"Timeout must be between 1 and {MaxTimeoutSeconds} seconds."); + } + + return new SshExecRequest( + host.Trim(), + username.Trim(), + command, + resolvedPort, + string.IsNullOrWhiteSpace(cwd) ? null : cwd, + resolvedTimeout, + string.IsNullOrWhiteSpace(keyPath) ? null : keyPath.Trim(), + string.IsNullOrEmpty(keyPassphrase) ? null : keyPassphrase); + } + + private void Log(SshExecRequest request, bool success, long durationMs, string? errorCode = null, string? message = null) + { + _auditLogger.Log(new AuditEvent( + _clock.UtcNow, + "ssh_exec", + request.Host, + request.Username, + request.Command, + success, + durationMs, + errorCode, + message)); + } +} diff --git a/src/McpSsh.Server/Ssh/SshToolException.cs b/src/McpSsh.Server/Ssh/SshToolException.cs new file mode 100644 index 0000000..cb40eb5 --- /dev/null +++ b/src/McpSsh.Server/Ssh/SshToolException.cs @@ -0,0 +1,18 @@ +namespace McpSsh.Server.Ssh; + +public sealed class SshToolException : Exception +{ + public SshToolException(string errorCode, string message) + : base(message) + { + ErrorCode = errorCode; + } + + public SshToolException(string errorCode, string message, Exception innerException) + : base(message, innerException) + { + ErrorCode = errorCode; + } + + public string ErrorCode { get; } +} diff --git a/src/McpSsh.Server/SystemClock.cs b/src/McpSsh.Server/SystemClock.cs new file mode 100644 index 0000000..f1260ca --- /dev/null +++ b/src/McpSsh.Server/SystemClock.cs @@ -0,0 +1,11 @@ +namespace McpSsh.Server; + +public interface ISystemClock +{ + DateTimeOffset UtcNow { get; } +} + +public sealed class SystemClock : ISystemClock +{ + public DateTimeOffset UtcNow => DateTimeOffset.UtcNow; +} diff --git a/src/McpSsh.Server/Terminal/TerminalConnection.cs b/src/McpSsh.Server/Terminal/TerminalConnection.cs new file mode 100644 index 0000000..bcba1ac --- /dev/null +++ b/src/McpSsh.Server/Terminal/TerminalConnection.cs @@ -0,0 +1,95 @@ +using Renci.SshNet; +using Renci.SshNet.Common; +using McpSsh.Server.Ssh; + +namespace McpSsh.Server.Terminal; + +public interface ITerminalConnection : IDisposable +{ + bool DataAvailable { get; } + + string ReadAvailable(); + + void Write(string input); +} + +public interface ITerminalConnectionFactory +{ + ITerminalConnection Create(TerminalStartRequest request); +} + +public sealed class SshNetTerminalConnectionFactory : ITerminalConnectionFactory +{ + private const int ShellBufferSize = 16 * 1024; + private readonly ISshKeyResolver _keyResolver; + + public SshNetTerminalConnectionFactory(ISshKeyResolver keyResolver) + { + _keyResolver = keyResolver; + } + + public ITerminalConnection Create(TerminalStartRequest request) + { + try + { + var keyPath = _keyResolver.ResolveKeyPath(request.KeyPath); + var keyFile = string.IsNullOrEmpty(request.KeyPassphrase) + ? new PrivateKeyFile(keyPath) + : new PrivateKeyFile(keyPath, request.KeyPassphrase); + var client = new SshClient(request.Host, request.Port, request.Username, keyFile); + + client.Connect(); + + var stream = client.CreateShellStream( + "xterm-256color", + (uint)request.Cols, + (uint)request.Rows, + width: 0, + height: 0, + bufferSize: ShellBufferSize); + + return new SshNetTerminalConnection(client, stream, keyFile); + } + catch (SshAuthenticationException ex) + { + throw new SshToolException("ssh_authentication_failed", "SSH key authentication failed. Check the username, key path, and key passphrase.", ex); + } + catch (SshException ex) + { + throw new SshToolException("ssh_error", ex.Message, ex); + } + } + + private sealed class SshNetTerminalConnection : ITerminalConnection + { + private readonly SshClient _client; + private readonly ShellStream _stream; + private readonly PrivateKeyFile _keyFile; + + public SshNetTerminalConnection(SshClient client, ShellStream stream, PrivateKeyFile keyFile) + { + _client = client; + _stream = stream; + _keyFile = keyFile; + } + + public bool DataAvailable => _stream.DataAvailable; + + public string ReadAvailable() + { + return _stream.Read(); + } + + public void Write(string input) + { + _stream.Write(input); + } + + public void Dispose() + { + _stream.Dispose(); + _client.Dispose(); + _keyFile.Dispose(); + } + } +} diff --git a/src/McpSsh.Server/Terminal/TerminalResults.cs b/src/McpSsh.Server/Terminal/TerminalResults.cs new file mode 100644 index 0000000..64a1b0c --- /dev/null +++ b/src/McpSsh.Server/Terminal/TerminalResults.cs @@ -0,0 +1,33 @@ +namespace McpSsh.Server.Terminal; + +public sealed record TerminalStartRequest( + string Host, + string Username, + string? Shell, + int Cols, + int Rows, + int Port, + int IdleTimeoutSeconds, + string? KeyPath, + string? KeyPassphrase); + +public sealed record TerminalStartResult( + string? SessionId, + string? Error = null, + string? Message = null); + +public sealed record TerminalWriteResult( + bool Accepted, + string? Error = null, + string? Message = null); + +public sealed record TerminalReadResult( + string Output, + bool Truncated, + string? Error = null, + string? Message = null); + +public sealed record TerminalStopResult( + bool Stopped, + string? Error = null, + string? Message = null); diff --git a/src/McpSsh.Server/Terminal/TerminalSessionManager.cs b/src/McpSsh.Server/Terminal/TerminalSessionManager.cs new file mode 100644 index 0000000..691acd9 --- /dev/null +++ b/src/McpSsh.Server/Terminal/TerminalSessionManager.cs @@ -0,0 +1,347 @@ +using System.Collections.Concurrent; +using System.Security.Cryptography; +using System.Text; +using McpSsh.Server.Audit; +using McpSsh.Server.Ssh; + +namespace McpSsh.Server.Terminal; + +public sealed class TerminalSessionManager : IDisposable +{ + public const int DefaultPort = 22; + public const int DefaultCols = 120; + public const int DefaultRows = 40; + public const int DefaultIdleTimeoutSeconds = 900; + public const int MaxIdleTimeoutSeconds = 86_400; + public const int DefaultReadMaxBytes = 12_000; + public const int MaxReadBytes = 1_000_000; + + private const int MaxBufferCharacters = 1_000_000; + private readonly ConcurrentDictionary _sessions = new(StringComparer.Ordinal); + private readonly ITerminalConnectionFactory _connectionFactory; + private readonly IAuditLogger _auditLogger; + private readonly ISystemClock _clock; + private readonly Timer _cleanupTimer; + + public TerminalSessionManager(ITerminalConnectionFactory connectionFactory, IAuditLogger auditLogger, ISystemClock clock) + { + _connectionFactory = connectionFactory; + _auditLogger = auditLogger; + _clock = clock; + _cleanupTimer = new Timer(_ => CleanupIdleSessions(), null, TimeSpan.FromMinutes(1), TimeSpan.FromMinutes(1)); + } + + public Task StartAsync( + string host, + string username, + string? shell, + int? cols, + int? rows, + int? port, + string? keyPath, + string? keyPassphrase, + int? idleTimeoutSeconds, + CancellationToken cancellationToken) + { + var request = ValidateStart(host, username, shell, cols, rows, port, keyPath, keyPassphrase, idleTimeoutSeconds); + var started = _clock.UtcNow; + + try + { + cancellationToken.ThrowIfCancellationRequested(); + + var connection = _connectionFactory.Create(request); + var sessionId = CreateSessionId(); + var session = new TerminalSession(sessionId, request.Host, started, request.IdleTimeoutSeconds, connection); + if (!_sessions.TryAdd(sessionId, session)) + { + connection.Dispose(); + throw new SshToolException("terminal_session_conflict", "Unable to allocate a unique terminal session ID."); + } + + if (!string.IsNullOrWhiteSpace(request.Shell)) + { + connection.Write($"exec {request.Shell}\n"); + } + + session.ReaderTask = Task.Run(() => ReadLoopAsync(session), CancellationToken.None); + Log("terminal_start", request.Host, success: true, command: request.Shell); + + return Task.FromResult(new TerminalStartResult(sessionId)); + } + catch (SshToolException ex) + { + Log("terminal_start", request.Host, success: false, errorCode: ex.ErrorCode, message: ex.Message, command: request.Shell); + return Task.FromResult(new TerminalStartResult(null, ex.ErrorCode, ex.Message)); + } + } + + public TerminalWriteResult Write(string sessionId, string input) + { + if (string.IsNullOrWhiteSpace(sessionId)) + { + return new TerminalWriteResult(false, "invalid_session_id", "Session ID is required."); + } + + if (input is null) + { + return new TerminalWriteResult(false, "invalid_input", "Input is required."); + } + + if (!_sessions.TryGetValue(sessionId, out var session)) + { + return new TerminalWriteResult(false, "terminal_session_not_found", $"Terminal session '{sessionId}' was not found."); + } + + try + { + session.Touch(_clock.UtcNow); + session.Connection.Write(input); + Log("terminal_write", session.Host, success: true, command: input); + return new TerminalWriteResult(true); + } + catch (ObjectDisposedException ex) + { + RemoveSession(sessionId); + Log("terminal_write", session.Host, success: false, errorCode: "terminal_session_closed", message: ex.Message, command: input); + return new TerminalWriteResult(false, "terminal_session_closed", "Terminal session is closed."); + } + } + + public TerminalReadResult Read(string sessionId, int? maxBytes) + { + if (string.IsNullOrWhiteSpace(sessionId)) + { + return new TerminalReadResult(string.Empty, false, "invalid_session_id", "Session ID is required."); + } + + if (!_sessions.TryGetValue(sessionId, out var session)) + { + return new TerminalReadResult(string.Empty, false, "terminal_session_not_found", $"Terminal session '{sessionId}' was not found."); + } + + var byteLimit = ValidateMaxBytes(maxBytes); + session.Touch(_clock.UtcNow); + var (output, truncated) = session.Drain(byteLimit); + Log("terminal_read", session.Host, success: true); + return new TerminalReadResult(output, truncated); + } + + public TerminalStopResult Stop(string sessionId) + { + if (string.IsNullOrWhiteSpace(sessionId)) + { + return new TerminalStopResult(false, "invalid_session_id", "Session ID is required."); + } + + if (!_sessions.TryRemove(sessionId, out var session)) + { + return new TerminalStopResult(false, "terminal_session_not_found", $"Terminal session '{sessionId}' was not found."); + } + + session.Dispose(); + Log("terminal_stop", session.Host, success: true); + return new TerminalStopResult(true); + } + + public void Dispose() + { + _cleanupTimer.Dispose(); + foreach (var sessionId in _sessions.Keys) + { + RemoveSession(sessionId); + } + } + + private async Task ReadLoopAsync(TerminalSession session) + { + while (!session.Cancellation.IsCancellationRequested) + { + try + { + if (session.Connection.DataAvailable) + { + var output = session.Connection.ReadAvailable(); + if (!string.IsNullOrEmpty(output)) + { + session.Append(output); + } + } + else + { + await Task.Delay(50, session.Cancellation).ConfigureAwait(false); + } + } + catch (OperationCanceledException) + { + return; + } + catch (ObjectDisposedException) + { + return; + } + } + } + + private static TerminalStartRequest ValidateStart(string host, string username, string? shell, int? cols, int? rows, int? port, string? keyPath, string? keyPassphrase, int? idleTimeoutSeconds) + { + if (string.IsNullOrWhiteSpace(host)) + { + throw new SshToolException("invalid_host", "Host is required."); + } + + if (string.IsNullOrWhiteSpace(username)) + { + throw new SshToolException("invalid_username", "Username is required."); + } + + var resolvedPort = port ?? DefaultPort; + if (resolvedPort is < 1 or > 65535) + { + throw new SshToolException("invalid_port", "Port must be between 1 and 65535."); + } + + var resolvedCols = cols ?? DefaultCols; + if (resolvedCols is < 20 or > 500) + { + throw new SshToolException("invalid_cols", "Terminal columns must be between 20 and 500."); + } + + var resolvedRows = rows ?? DefaultRows; + if (resolvedRows is < 5 or > 500) + { + throw new SshToolException("invalid_rows", "Terminal rows must be between 5 and 500."); + } + + var resolvedIdleTimeout = idleTimeoutSeconds ?? DefaultIdleTimeoutSeconds; + if (resolvedIdleTimeout is < 1 or > MaxIdleTimeoutSeconds) + { + throw new SshToolException("invalid_idle_timeout", $"Idle timeout must be between 1 and {MaxIdleTimeoutSeconds} seconds."); + } + + return new TerminalStartRequest( + host.Trim(), + username.Trim(), + string.IsNullOrWhiteSpace(shell) ? null : shell.Trim(), + resolvedCols, + resolvedRows, + resolvedPort, + resolvedIdleTimeout, + string.IsNullOrWhiteSpace(keyPath) ? null : keyPath.Trim(), + string.IsNullOrEmpty(keyPassphrase) ? null : keyPassphrase); + } + + private static int ValidateMaxBytes(int? maxBytes) + { + var resolved = maxBytes ?? DefaultReadMaxBytes; + if (resolved is < 1 or > MaxReadBytes) + { + throw new SshToolException("invalid_max_bytes", $"maxBytes must be between 1 and {MaxReadBytes}."); + } + + return resolved; + } + + private void CleanupIdleSessions() + { + var now = _clock.UtcNow; + foreach (var (sessionId, session) in _sessions) + { + if (now - session.LastActivityUtc >= TimeSpan.FromSeconds(session.IdleTimeoutSeconds)) + { + RemoveSession(sessionId); + Log("terminal_idle_cleanup", session.Host, success: true); + } + } + } + + private void RemoveSession(string sessionId) + { + if (_sessions.TryRemove(sessionId, out var session)) + { + session.Dispose(); + } + } + + private void Log(string tool, string? host, bool success, string? errorCode = null, string? message = null, string? command = null) + { + _auditLogger.Log(new AuditEvent(_clock.UtcNow, tool, host, null, command, success, 0, errorCode, message)); + } + + private static string CreateSessionId() + { + Span bytes = stackalloc byte[12]; + RandomNumberGenerator.Fill(bytes); + return $"term_{Convert.ToHexString(bytes).ToLowerInvariant()}"; + } + + private sealed class TerminalSession : IDisposable + { + private readonly object _gate = new(); + private readonly StringBuilder _buffer = new(); + private readonly CancellationTokenSource _cancellation = new(); + + public TerminalSession(string sessionId, string host, DateTimeOffset now, int idleTimeoutSeconds, ITerminalConnection connection) + { + SessionId = sessionId; + Host = host; + CreatedUtc = now; + LastActivityUtc = now; + IdleTimeoutSeconds = idleTimeoutSeconds; + Connection = connection; + Cancellation = _cancellation.Token; + } + + public string SessionId { get; } + public string Host { get; } + public DateTimeOffset CreatedUtc { get; } + public DateTimeOffset LastActivityUtc { get; private set; } + public int IdleTimeoutSeconds { get; } + public ITerminalConnection Connection { get; } + public CancellationToken Cancellation { get; } + public Task? ReaderTask { get; set; } + + public void Touch(DateTimeOffset now) + { + lock (_gate) + { + LastActivityUtc = now; + } + } + + public void Append(string output) + { + lock (_gate) + { + _buffer.Append(output); + if (_buffer.Length > MaxBufferCharacters) + { + _buffer.Remove(0, _buffer.Length - MaxBufferCharacters); + } + } + } + + public (string Output, bool Truncated) Drain(int maxCharacters) + { + lock (_gate) + { + if (_buffer.Length == 0) + { + return (string.Empty, false); + } + + var count = Math.Min(maxCharacters, _buffer.Length); + var output = _buffer.ToString(0, count); + _buffer.Remove(0, count); + return (output, _buffer.Length > 0); + } + } + + public void Dispose() + { + _cancellation.Cancel(); + Connection.Dispose(); + _cancellation.Dispose(); + } + } +} diff --git a/src/McpSsh.Server/Tools/SshTools.cs b/src/McpSsh.Server/Tools/SshTools.cs new file mode 100644 index 0000000..e1696ae --- /dev/null +++ b/src/McpSsh.Server/Tools/SshTools.cs @@ -0,0 +1,78 @@ +using System.ComponentModel; +using McpSsh.Server.Ssh; +using McpSsh.Server.Terminal; +using ModelContextProtocol.Server; + +namespace McpSsh.Server.Tools; + +[McpServerToolType] +public sealed class SshTools +{ + private readonly SshExecService _sshExecService; + private readonly TerminalSessionManager _terminalSessionManager; + + public SshTools(SshExecService sshExecService, TerminalSessionManager terminalSessionManager) + { + _sshExecService = sshExecService; + _terminalSessionManager = terminalSessionManager; + } + + [McpServerTool(Name = "ssh_exec", Destructive = true)] + [Description("Execute a single command over SSH using key-based authentication.")] + public Task ExecuteAsync( + [Description("Remote hostname or IP address. OpenSSH aliases are not supported in this vertical slice.")] string host, + [Description("Remote SSH username.")] string username, + [Description("Command to execute on the remote host.")] string command, + [Description("Remote SSH port. Defaults to 22.")] int? port = null, + [Description("Optional remote working directory.")] string? cwd = null, + [Description("Optional local private key path. Defaults to ~/.ssh/id_ed25519, ~/.ssh/id_ecdsa, then ~/.ssh/id_rsa.")] string? keyPath = null, + [Description("Optional private key passphrase. Use only with trusted MCP clients.")] string? keyPassphrase = null, + [Description("Timeout in seconds. Defaults to 30 and is capped at 300.")] int? timeoutSeconds = null, + CancellationToken cancellationToken = default) + { + return _sshExecService.ExecuteAsync(host, username, command, port, cwd, keyPath, keyPassphrase, timeoutSeconds, cancellationToken); + } + + [McpServerTool(Name = "terminal_start", Destructive = true)] + [Description("Start a persistent SSH PTY shell session using key-based authentication.")] + public Task StartTerminalAsync( + [Description("Remote hostname or IP address. OpenSSH aliases are not supported in this vertical slice.")] string host, + [Description("Remote SSH username.")] string username, + [Description("Optional shell to exec after PTY allocation, for example bash or sh.")] string? shell = null, + [Description("Terminal columns. Defaults to 120.")] int? cols = null, + [Description("Terminal rows. Defaults to 40.")] int? rows = null, + [Description("Remote SSH port. Defaults to 22.")] int? port = null, + [Description("Optional local private key path. Defaults to ~/.ssh/id_ed25519, ~/.ssh/id_ecdsa, then ~/.ssh/id_rsa.")] string? keyPath = null, + [Description("Optional private key passphrase. Use only with trusted MCP clients.")] string? keyPassphrase = null, + [Description("Idle timeout in seconds. Defaults to 900.")] int? idleTimeoutSeconds = null, + CancellationToken cancellationToken = default) + { + return _terminalSessionManager.StartAsync(host, username, shell, cols, rows, port, keyPath, keyPassphrase, idleTimeoutSeconds, cancellationToken); + } + + [McpServerTool(Name = "terminal_write", Destructive = true)] + [Description("Write input to an active SSH terminal session.")] + public TerminalWriteResult WriteTerminal( + [Description("Terminal session ID returned by terminal_start.")] string sessionId, + [Description("Input to write to the terminal. Include newline characters when submitting commands.")] string input) + { + return _terminalSessionManager.Write(sessionId, input); + } + + [McpServerTool(Name = "terminal_read", Destructive = false)] + [Description("Read buffered output from an active SSH terminal session.")] + public TerminalReadResult ReadTerminal( + [Description("Terminal session ID returned by terminal_start.")] string sessionId, + [Description("Maximum output characters to read. Defaults to 12000.")] int? maxBytes = null) + { + return _terminalSessionManager.Read(sessionId, maxBytes); + } + + [McpServerTool(Name = "terminal_stop", Destructive = true)] + [Description("Stop and remove an SSH terminal session.")] + public TerminalStopResult StopTerminal( + [Description("Terminal session ID returned by terminal_start.")] string sessionId) + { + return _terminalSessionManager.Stop(sessionId); + } +} diff --git a/tests/McpSsh.Tests/DefaultSshKeyResolverTests.cs b/tests/McpSsh.Tests/DefaultSshKeyResolverTests.cs new file mode 100644 index 0000000..5100149 --- /dev/null +++ b/tests/McpSsh.Tests/DefaultSshKeyResolverTests.cs @@ -0,0 +1,59 @@ +using McpSsh.Server.Ssh; + +namespace McpSsh.Tests; + +public sealed class DefaultSshKeyResolverTests +{ + [Fact] + public void ResolveDefaultKeyPath_ReturnsFirstExistingDefaultKey() + { + var fileSystem = new FakeFileSystem("/home/test/.ssh/id_ecdsa", "/home/test/.ssh/id_rsa"); + var resolver = new DefaultSshKeyResolver(fileSystem, "/home/test/.ssh"); + + var path = resolver.ResolveKeyPath(null); + + Assert.Equal("/home/test/.ssh/id_ecdsa", path); + } + + [Fact] + public void ResolveDefaultKeyPath_ThrowsWhenNoDefaultKeyExists() + { + var resolver = new DefaultSshKeyResolver(new FakeFileSystem(), "/home/test/.ssh"); + + var ex = Assert.Throws(() => resolver.ResolveKeyPath(null)); + + Assert.Equal("ssh_key_not_found", ex.ErrorCode); + } + + [Fact] + public void ResolveKeyPath_ReturnsExplicitKeyWhenItExists() + { + var resolver = new DefaultSshKeyResolver(new FakeFileSystem("/keys/deploy_ed25519"), "/home/test/.ssh"); + + var path = resolver.ResolveKeyPath("/keys/deploy_ed25519"); + + Assert.Equal("/keys/deploy_ed25519", path); + } + + [Fact] + public void ResolveKeyPath_ThrowsWhenExplicitKeyDoesNotExist() + { + var resolver = new DefaultSshKeyResolver(new FakeFileSystem(), "/home/test/.ssh"); + + var ex = Assert.Throws(() => resolver.ResolveKeyPath("/keys/missing")); + + Assert.Equal("ssh_key_not_found", ex.ErrorCode); + } + + private sealed class FakeFileSystem : IFileSystem + { + private readonly HashSet _paths; + + public FakeFileSystem(params string[] paths) + { + _paths = paths.ToHashSet(StringComparer.Ordinal); + } + + public bool FileExists(string path) => _paths.Contains(path); + } +} diff --git a/tests/McpSsh.Tests/JsonLineAuditLoggerTests.cs b/tests/McpSsh.Tests/JsonLineAuditLoggerTests.cs new file mode 100644 index 0000000..b6ae41f --- /dev/null +++ b/tests/McpSsh.Tests/JsonLineAuditLoggerTests.cs @@ -0,0 +1,31 @@ +using System.Text.Json; +using McpSsh.Server.Audit; + +namespace McpSsh.Tests; + +public sealed class JsonLineAuditLoggerTests +{ + [Fact] + public void Log_WritesJsonLineAndRedactsSensitiveMarkers() + { + using var writer = new StringWriter(); + var logger = new JsonLineAuditLogger(writer); + + logger.Log(new AuditEvent( + DateTimeOffset.Parse("2026-05-24T12:00:00Z"), + "ssh_exec", + "prod-api", + "deploy", + "echo token=abc123", + Success: true, + DurationMs: 42)); + + using var document = JsonDocument.Parse(writer.ToString()); + var root = document.RootElement; + + Assert.Equal("ssh_exec", root.GetProperty("tool").GetString()); + Assert.Equal("prod-api", root.GetProperty("host").GetString()); + Assert.Equal("echo token=***", root.GetProperty("command").GetString()); + Assert.True(root.GetProperty("success").GetBoolean()); + } +} diff --git a/tests/McpSsh.Tests/McpSsh.Tests.csproj b/tests/McpSsh.Tests/McpSsh.Tests.csproj new file mode 100644 index 0000000..56bea59 --- /dev/null +++ b/tests/McpSsh.Tests/McpSsh.Tests.csproj @@ -0,0 +1,25 @@ + + + + net10.0 + enable + enable + false + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/tests/McpSsh.Tests/RemoteShellCommandTests.cs b/tests/McpSsh.Tests/RemoteShellCommandTests.cs new file mode 100644 index 0000000..ab23932 --- /dev/null +++ b/tests/McpSsh.Tests/RemoteShellCommandTests.cs @@ -0,0 +1,20 @@ +using McpSsh.Server.Ssh; + +namespace McpSsh.Tests; + +public sealed class RemoteShellCommandTests +{ + [Fact] + public void Build_ReturnsCommandWhenCwdIsMissing() + { + Assert.Equal("pwd", RemoteShellCommand.Build("pwd", null)); + } + + [Fact] + public void Build_PrependsQuotedCwd() + { + var command = RemoteShellCommand.Build("ls", "/srv/app's/current"); + + Assert.Equal("cd '/srv/app'\\''s/current' && ls", command); + } +} diff --git a/tests/McpSsh.Tests/SshExecServiceTests.cs b/tests/McpSsh.Tests/SshExecServiceTests.cs new file mode 100644 index 0000000..dbd5f93 --- /dev/null +++ b/tests/McpSsh.Tests/SshExecServiceTests.cs @@ -0,0 +1,110 @@ +using McpSsh.Server; +using McpSsh.Server.Audit; +using McpSsh.Server.Ssh; + +namespace McpSsh.Tests; + +public sealed class SshExecServiceTests +{ + [Fact] + public async Task ExecuteAsync_PassesValidatedRequestToExecutor() + { + var executor = new CapturingExecutor(new SshExecResult(0, "ok", "", 12, false, null, null)); + var auditLogger = new CapturingAuditLogger(); + var service = CreateService(executor, auditLogger: auditLogger); + + var result = await service.ExecuteAsync(" prod-api ", " deploy ", "uptime", null, null, " /keys/deploy ", "secret", null, CancellationToken.None); + + Assert.Equal(0, result.ExitCode); + Assert.NotNull(executor.Request); + Assert.Equal("prod-api", executor.Request.Host); + Assert.Equal("deploy", executor.Request.Username); + Assert.Equal(22, executor.Request.Port); + Assert.Equal(30, executor.Request.TimeoutSeconds); + Assert.Equal("/keys/deploy", executor.Request.KeyPath); + Assert.Equal("secret", executor.Request.KeyPassphrase); + Assert.Single(auditLogger.Events); + Assert.True(auditLogger.Events[0].Success); + } + + [Fact] + public async Task ExecuteAsync_AllowsDestructiveCommands() + { + var executor = new CapturingExecutor(); + var service = CreateService(executor); + + var result = await service.ExecuteAsync("prod-api", "deploy", "rm -rf /", null, null, null, null, null, CancellationToken.None); + + Assert.Equal(0, result.ExitCode); + Assert.Equal("rm -rf /", executor.Request?.Command); + } + + [Fact] + public async Task ExecuteAsync_ReturnsTimeoutWhenExecutorObservesCancellation() + { + var executor = new BlockingExecutor(); + var service = CreateService(executor); + + var result = await service.ExecuteAsync("prod-api", "deploy", "sleep 5", null, null, null, null, 1, CancellationToken.None); + + Assert.True(result.TimedOut); + Assert.Equal("ssh_timeout", result.Error); + } + + private static SshExecService CreateService( + ISshCommandExecutor executor, + IAuditLogger? auditLogger = null) + { + return new SshExecService( + executor, + auditLogger ?? new CapturingAuditLogger(), + new FixedClock()); + } + + private sealed class CapturingExecutor : ISshCommandExecutor + { + private readonly SshExecResult _result; + + public CapturingExecutor() + : this(new SshExecResult(0, "", "", 1, false, null, null)) + { + } + + public CapturingExecutor(SshExecResult result) + { + _result = result; + } + + public SshExecRequest? Request { get; private set; } + + public Task ExecuteAsync(SshExecRequest request, CancellationToken cancellationToken) + { + Request = request; + return Task.FromResult(_result); + } + } + + private sealed class BlockingExecutor : ISshCommandExecutor + { + public async Task ExecuteAsync(SshExecRequest request, CancellationToken cancellationToken) + { + await Task.Delay(TimeSpan.FromMinutes(1), cancellationToken); + return new SshExecResult(0, "", "", 0, false, null, null); + } + } + + private sealed class CapturingAuditLogger : IAuditLogger + { + public List Events { get; } = []; + + public void Log(AuditEvent auditEvent) + { + Events.Add(auditEvent); + } + } + + private sealed class FixedClock : ISystemClock + { + public DateTimeOffset UtcNow => DateTimeOffset.Parse("2026-05-24T12:00:00Z"); + } +} diff --git a/tests/McpSsh.Tests/TerminalSessionManagerTests.cs b/tests/McpSsh.Tests/TerminalSessionManagerTests.cs new file mode 100644 index 0000000..4a9ef1c --- /dev/null +++ b/tests/McpSsh.Tests/TerminalSessionManagerTests.cs @@ -0,0 +1,158 @@ +using System.Collections.Concurrent; +using McpSsh.Server; +using McpSsh.Server.Audit; +using McpSsh.Server.Terminal; + +namespace McpSsh.Tests; + +public sealed class TerminalSessionManagerTests +{ + [Fact] + public async Task StartAsync_CreatesSessionAndExecsRequestedShell() + { + using var manager = CreateManager(out var factory, out _); + + var result = await manager.StartAsync(" prod-api ", " deploy ", "bash", null, null, null, "/keys/id", "secret", null, CancellationToken.None); + + Assert.Null(result.Error); + Assert.StartsWith("term_", result.SessionId); + Assert.NotNull(factory.Request); + Assert.Equal("prod-api", factory.Request.Host); + Assert.Equal("deploy", factory.Request.Username); + Assert.Equal("/keys/id", factory.Request.KeyPath); + Assert.Equal("secret", factory.Request.KeyPassphrase); + Assert.Contains("exec bash\n", factory.Connection.Writes); + } + + [Fact] + public async Task Write_SendsInputToActiveSession() + { + using var manager = CreateManager(out var factory, out _); + var start = await manager.StartAsync("prod-api", "deploy", null, null, null, null, null, null, null, CancellationToken.None); + + var result = manager.Write(start.SessionId!, "uptime\n"); + + Assert.True(result.Accepted); + Assert.Contains("uptime\n", factory.Connection.Writes); + } + + [Fact] + public async Task Read_DrainsBufferedOutputAndReportsTruncation() + { + using var manager = CreateManager(out var factory, out _); + var start = await manager.StartAsync("prod-api", "deploy", null, null, null, null, null, null, null, CancellationToken.None); + factory.Connection.QueueOutput("abcdef"); + + var first = await ReadUntilOutputAsync(manager, start.SessionId!, 3); + var second = manager.Read(start.SessionId!, 10); + + Assert.Equal("abc", first.Output); + Assert.True(first.Truncated); + Assert.Equal("def", second.Output); + Assert.False(second.Truncated); + } + + [Fact] + public void Write_ReturnsNotFoundForMissingSession() + { + using var manager = CreateManager(out _, out _); + + var result = manager.Write("term_missing", "pwd\n"); + + Assert.False(result.Accepted); + Assert.Equal("terminal_session_not_found", result.Error); + } + + [Fact] + public async Task Stop_DisposesAndRemovesSession() + { + using var manager = CreateManager(out var factory, out _); + var start = await manager.StartAsync("prod-api", "deploy", null, null, null, null, null, null, null, CancellationToken.None); + + var result = manager.Stop(start.SessionId!); + var writeAfterStop = manager.Write(start.SessionId!, "pwd\n"); + + Assert.True(result.Stopped); + Assert.True(factory.Connection.Disposed); + Assert.Equal("terminal_session_not_found", writeAfterStop.Error); + } + + private static TerminalSessionManager CreateManager(out FakeTerminalConnectionFactory factory, out CapturingAuditLogger auditLogger) + { + factory = new FakeTerminalConnectionFactory(); + auditLogger = new CapturingAuditLogger(); + return new TerminalSessionManager(factory, auditLogger, new FixedClock()); + } + + private static async Task ReadUntilOutputAsync(TerminalSessionManager manager, string sessionId, int maxBytes) + { + for (var attempt = 0; attempt < 20; attempt++) + { + var result = manager.Read(sessionId, maxBytes); + if (result.Output.Length > 0) + { + return result; + } + + await Task.Delay(50); + } + + return manager.Read(sessionId, maxBytes); + } + + private sealed class FakeTerminalConnectionFactory : ITerminalConnectionFactory + { + public FakeTerminalConnection Connection { get; } = new(); + public TerminalStartRequest? Request { get; private set; } + + public ITerminalConnection Create(TerminalStartRequest request) + { + Request = request; + return Connection; + } + } + + private sealed class FakeTerminalConnection : ITerminalConnection + { + private readonly ConcurrentQueue _output = new(); + + public List Writes { get; } = []; + public bool Disposed { get; private set; } + public bool DataAvailable => !_output.IsEmpty; + + public string ReadAvailable() + { + return _output.TryDequeue(out var output) ? output : string.Empty; + } + + public void Write(string input) + { + Writes.Add(input); + } + + public void QueueOutput(string output) + { + _output.Enqueue(output); + } + + public void Dispose() + { + Disposed = true; + } + } + + private sealed class CapturingAuditLogger : IAuditLogger + { + public List Events { get; } = []; + + public void Log(AuditEvent auditEvent) + { + Events.Add(auditEvent); + } + } + + private sealed class FixedClock : ISystemClock + { + public DateTimeOffset UtcNow => DateTimeOffset.Parse("2026-05-24T12:00:00Z"); + } +}