diff --git a/AUTHENTICATION.md b/AUTHENTICATION.md new file mode 100644 index 0000000..f080a0e --- /dev/null +++ b/AUTHENTICATION.md @@ -0,0 +1,269 @@ +# MySQL Authentication Implementation + +This document describes the MySQL authentication implementation for SuperSocket.MySQL that follows the MySQL protocol specification. + +## Overview + +The authentication system implements the standard MySQL handshake protocol: + +1. **Server Hello**: Server sends initial handshake packet with protocol version, server version, connection ID, and authentication challenge (salt) +2. **Client Authentication**: Client responds with username, scrambled password, and other connection parameters +3. **Server Response**: Server validates credentials and sends OK packet (success) or ERR packet (failure) + +## Architecture + +The implementation consists of several key components: + +### Core Classes + +#### `MySQLHandshakePacket` +- Represents the initial handshake packet sent by the server +- Contains protocol version, server version, connection ID, and 20-byte salt +- Generates binary packet data according to MySQL protocol specification +- Located: `src/SuperSocket.MySQL/Authentication/MySQLHandshakePacket.cs` + +#### `MySQLHandshakeResponsePacket` +- Represents the client's authentication response +- Parses client capabilities, username, scrambled password, and database name +- Handles binary packet parsing from SuperSocket pipeline +- Located: `src/SuperSocket.MySQL/Authentication/MySQLHandshakeResponsePacket.cs` + +#### `MySQLAuthenticationHandler` +- Coordinates the authentication flow +- Implements MySQL native password scrambling algorithm (SHA1-based) +- Validates credentials against hardcoded username/password ("test"/"test") +- Generates OK and ERR response packets +- Located: `src/SuperSocket.MySQL/Authentication/MySQLAuthenticationHandler.cs` + +#### `MySQLSession` +- Extends SuperSocket AppSession to handle MySQL connections +- Automatically sends handshake packet on connection +- Manages authentication state and credential validation +- Located: `src/SuperSocket.MySQL/Authentication/MySQLSession.cs` + +#### `MySQLHandshakeResponseFilter` +- SuperSocket filter for parsing handshake response packets +- Integrates with SuperSocket's PackagePartsPipelineFilter system +- Located: `src/SuperSocket.MySQL/Authentication/MySQLHandshakeResponseFilter.cs` + +## Usage + +### Basic Authentication Setup + +```csharp +using SuperSocket.MySQL.Authentication; + +// Create authentication handler +var authHandler = new MySQLAuthenticationHandler(); + +// Generate handshake packet +var handshake = authHandler.CreateHandshake(); +var handshakeBytes = handshake.ToBytes(); + +// Send handshake to client +await session.SendAsync(handshakeBytes); + +// Parse client response +var response = MySQLHandshakeResponsePacket.ParseFromBytes(clientData, 0, clientData.Length); + +// Validate credentials +var salt = handshake.GetFullSalt(); +bool isValid = authHandler.ValidateCredentials(response, salt); + +if (isValid) +{ + var okPacket = authHandler.CreateOkPacket(); + await session.SendAsync(okPacket); +} +else +{ + var errorPacket = authHandler.CreateErrorPacket(1045, "Access denied"); + await session.SendAsync(errorPacket); +} +``` + +### SuperSocket Integration + +```csharp +using SuperSocket.MySQL.Authentication; +using SuperSocket.Server; + +// Create server with MySQL authentication +var host = SuperSocketHostBuilder + .Create() + .UseSession() + .ConfigureServices((context, services) => + { + services.Configure(options => + { + options.Listeners = new[] + { + new ListenOptions { Ip = "127.0.0.1", Port = 3306 } + }; + }); + }) + .Build(); + +await host.RunAsync(); +``` + +### Custom Session Implementation + +```csharp +public class CustomMySQLSession : MySQLSession +{ + protected override async ValueTask OnPackageReceived(MySQLHandshakeResponsePacket package) + { + if (!IsAuthenticated) + { + var success = await HandleAuthenticationAsync(package); + if (!success) + { + await CloseAsync(); + return; + } + + // Authentication successful, ready for command processing + Logger?.LogInformation($"User {package.Username} authenticated successfully"); + } + else + { + // Handle MySQL commands after authentication + await ProcessMySQLCommand(package); + } + } + + private async Task ProcessMySQLCommand(MySQLHandshakeResponsePacket package) + { + // Implement your MySQL command processing logic here + // This would typically involve switching to a different packet filter + // for handling SQL queries, prepared statements, etc. + } +} +``` + +## Protocol Details + +### Handshake Packet Structure + +The handshake packet follows MySQL Protocol Version 10 format: + +``` +1 byte - Protocol version (10) +string - Server version (null-terminated) +4 bytes - Connection ID +8 bytes - Authentication plugin data part 1 +1 byte - Filler (0x00) +2 bytes - Capability flags (lower 2 bytes) +1 byte - Character set +2 bytes - Status flags +2 bytes - Capability flags (upper 2 bytes) +1 byte - Length of authentication plugin data +10 bytes - Reserved (all zeros) +12 bytes - Authentication plugin data part 2 +1 byte - Null terminator +string - Authentication plugin name (null-terminated) +``` + +### Password Scrambling Algorithm + +The implementation uses MySQL's native password authentication: + +``` +SHA1(password) XOR SHA1(salt + SHA1(SHA1(password))) +``` + +Where: +- `password` is the plaintext password +- `salt` is the 20-byte challenge from the handshake packet +- `SHA1()` is the SHA-1 hash function +- `XOR` is bitwise exclusive OR + +### Authentication Flow + +1. **Connection Established**: Client connects to server +2. **Server Hello**: Server immediately sends handshake packet with unique salt +3. **Client Response**: Client sends handshake response with scrambled password +4. **Validation**: Server validates username and password using salt +5. **Result**: Server sends OK packet (authentication success) or ERR packet (failure) + +## Configuration + +### Hardcoded Credentials + +Currently, the system uses hardcoded credentials for simplicity: +- Username: `"test"` +- Password: `"test"` + +To modify credentials, update the `MySQLAuthenticationHandler` class: + +```csharp +private readonly string _validUsername = "your_username"; +private readonly string _validPassword = "your_password"; +``` + +### Server Information + +Default server information can be customized in `MySQLHandshakePacket`: + +```csharp +public byte ProtocolVersion { get; set; } = 10; +public string ServerVersion { get; set; } = "8.0.0-supersocket"; +public ushort CapabilityFlagsLower { get; set; } = 0xF7FF; +public byte CharacterSet { get; set; } = 0x21; // utf8_general_ci +``` + +## Testing + +Basic tests are provided in `src/SuperSocket.MySQL/Tests/AuthenticationTest.cs`: + +```csharp +// Test handshake packet generation +AuthenticationTest.TestHandshakePacket(); + +// Test password scrambling +AuthenticationTest.TestPasswordScrambling(); + +// Test OK/ERR packet generation +AuthenticationTest.TestOkErrorPackets(); + +// Run all tests +AuthenticationTest.RunAllTests(); +``` + +## Limitations + +This is a minimal implementation with the following limitations: + +1. **Single User**: Only supports one hardcoded username/password combination +2. **No Database Support**: Database name in client response is ignored +3. **Basic Capabilities**: Only implements essential capability flags +4. **No SSL/TLS**: Does not support encrypted connections +5. **No Prepared Statements**: Authentication only, no SQL processing +6. **No Multi-Factor Auth**: Only supports password authentication + +## Security Considerations + +1. **Password Storage**: Passwords are hardcoded in source code (development only) +2. **Salt Generation**: Uses cryptographically secure random number generator +3. **Protocol Compliance**: Follows MySQL native password authentication standard +4. **Connection Limits**: No built-in connection throttling or rate limiting + +## Future Enhancements + +Potential improvements for production use: + +1. **Database Integration**: Store user credentials in database +2. **Multiple Auth Plugins**: Support for additional authentication methods +3. **SSL/TLS Support**: Encrypted connection support +4. **Configuration**: External configuration for server settings +5. **Logging**: Comprehensive audit logging +6. **Performance**: Connection pooling and optimization +7. **Command Processing**: Full MySQL protocol command handling + +## References + +- [MySQL Protocol Documentation](https://dev.mysql.com/doc/dev/mysql-server/8.0.11/PAGE_PROTOCOL.html) +- [MySQL Handshake Protocol](https://dev.mysql.com/doc/dev/mysql-server/8.0.11/page_protocol_connection_phase_packets_protocol_handshake_v10.html) +- [MySQL Authentication](https://dev.mysql.com/doc/dev/mysql-server/8.0.11/page_protocol_connection_phase_packets_protocol_handshake_response.html) +- [SuperSocket Documentation](https://docs.supersocket.net/) \ No newline at end of file diff --git a/src/SuperSocket.MySQL/Authentication/MySQLAuthenticationHandler.cs b/src/SuperSocket.MySQL/Authentication/MySQLAuthenticationHandler.cs new file mode 100644 index 0000000..872768b --- /dev/null +++ b/src/SuperSocket.MySQL/Authentication/MySQLAuthenticationHandler.cs @@ -0,0 +1,173 @@ +using System; +using System.Security.Cryptography; +using System.Text; + +namespace SuperSocket.MySQL.Authentication +{ + /// + /// Handles MySQL authentication flow including handshake, challenge, and response validation. + /// Implements MySQL native password authentication using SHA1-based scrambling. + /// Based on MySQL Protocol specification: + /// https://dev.mysql.com/doc/dev/mysql-server/8.0.11/page_protocol_connection_phase_packets_protocol_handshake_v10.html + /// + public partial class MySQLAuthenticationHandler + { + private static uint _nextConnectionId = 1; + private readonly string _validUsername = "test"; + private readonly string _validPassword = "test"; + + public MySQLHandshakePacket CreateHandshake() + { + var handshake = new MySQLHandshakePacket + { + ConnectionId = GetNextConnectionId() + }; + return handshake; + } + + public bool ValidateCredentials(MySQLHandshakeResponsePacket response, byte[] salt) + { + if (string.IsNullOrEmpty(response.Username)) + return false; + + // Check username + if (!string.Equals(response.Username, _validUsername, StringComparison.Ordinal)) + return false; + + // Validate password using MySQL native password scrambling + if (response.AuthResponse == null || response.AuthResponse.Length == 0) + { + // Empty password - only valid if expected password is also empty + return string.IsNullOrEmpty(_validPassword); + } + + var expectedScramble = ScramblePassword(_validPassword, salt); + return CompareByteArrays(response.AuthResponse, expectedScramble); + } + + public byte[] CreateOkPacket() + { + // MySQL OK packet format: + // Header (4 bytes) + OK byte (0x00) + affected_rows + last_insert_id + status_flags + warnings + var packet = new byte[11]; + + // Packet length (7 bytes) + packet[0] = 0x07; + packet[1] = 0x00; + packet[2] = 0x00; + + // Packet sequence number + packet[3] = 0x02; + + // OK indicator + packet[4] = 0x00; + + // Affected rows (encoded integer) + packet[5] = 0x00; + + // Last insert ID (encoded integer) + packet[6] = 0x00; + + // Status flags + packet[7] = 0x02; // SERVER_STATUS_AUTOCOMMIT + packet[8] = 0x00; + + // Warnings + packet[9] = 0x00; + packet[10] = 0x00; + + return packet; + } + + public byte[] CreateErrorPacket(ushort errorCode, string message) + { + var messageBytes = Encoding.UTF8.GetBytes(message ?? "Authentication failed"); + var sqlState = Encoding.UTF8.GetBytes("28000"); // Access denied error + + // Calculate packet length: error marker (1) + error code (2) + sql state marker (1) + sql state (5) + message + var packetLength = 1 + 2 + 1 + 5 + messageBytes.Length; + var packet = new byte[4 + packetLength]; + + int offset = 0; + + // Packet header + packet[offset++] = (byte)(packetLength & 0xFF); + packet[offset++] = (byte)((packetLength >> 8) & 0xFF); + packet[offset++] = (byte)((packetLength >> 16) & 0xFF); + packet[offset++] = 0x02; // Packet sequence number + + // Error marker + packet[offset++] = 0xFF; + + // Error code + packet[offset++] = (byte)(errorCode & 0xFF); + packet[offset++] = (byte)((errorCode >> 8) & 0xFF); + + // SQL state marker + packet[offset++] = 0x23; // '#' + + // SQL state + Array.Copy(sqlState, 0, packet, offset, 5); + offset += 5; + + // Error message + Array.Copy(messageBytes, 0, packet, offset, messageBytes.Length); + + return packet; + } + + /// + /// Implements MySQL native password scrambling algorithm. + /// SHA1(password) XOR SHA1(salt + SHA1(SHA1(password))) + /// + private byte[] ScramblePassword(string password, byte[] salt) + { + if (string.IsNullOrEmpty(password)) + return new byte[0]; + + using (var sha1 = SHA1.Create()) + { + // Stage 1: SHA1(password) + var passwordBytes = Encoding.UTF8.GetBytes(password); + var stage1Hash = sha1.ComputeHash(passwordBytes); + + // Stage 2: SHA1(SHA1(password)) + var stage2Hash = sha1.ComputeHash(stage1Hash); + + // Stage 3: SHA1(salt + SHA1(SHA1(password))) + var saltAndStage2 = new byte[salt.Length + stage2Hash.Length]; + Array.Copy(salt, 0, saltAndStage2, 0, salt.Length); + Array.Copy(stage2Hash, 0, saltAndStage2, salt.Length, stage2Hash.Length); + var stage3Hash = sha1.ComputeHash(saltAndStage2); + + // Final: SHA1(password) XOR SHA1(salt + SHA1(SHA1(password))) + var scramble = new byte[stage1Hash.Length]; + for (int i = 0; i < stage1Hash.Length; i++) + { + scramble[i] = (byte)(stage1Hash[i] ^ stage3Hash[i]); + } + + return scramble; + } + } + + private bool CompareByteArrays(byte[] array1, byte[] array2) + { + if (array1.Length != array2.Length) + return false; + + for (int i = 0; i < array1.Length; i++) + { + if (array1[i] != array2[i]) + return false; + } + + return true; + } + + private static uint GetNextConnectionId() + { + return _nextConnectionId++; + } + } +} \ No newline at end of file diff --git a/src/SuperSocket.MySQL/Authentication/MySQLHandshakePacket.cs b/src/SuperSocket.MySQL/Authentication/MySQLHandshakePacket.cs new file mode 100644 index 0000000..989909c --- /dev/null +++ b/src/SuperSocket.MySQL/Authentication/MySQLHandshakePacket.cs @@ -0,0 +1,137 @@ +using System; +using System.Buffers; +using System.Security.Cryptography; +using System.Text; + +namespace SuperSocket.MySQL.Authentication +{ + /// + /// Represents the initial handshake packet sent by the server to the client. + /// Based on MySQL Protocol specification: + /// https://dev.mysql.com/doc/dev/mysql-server/8.0.11/page_protocol_connection_phase_packets_protocol_handshake_v10.html + /// + public class MySQLHandshakePacket + { + public byte ProtocolVersion { get; set; } = 10; + public string ServerVersion { get; set; } = "8.0.0-supersocket"; + public uint ConnectionId { get; set; } + public byte[] AuthPluginDataPart1 { get; set; } = new byte[8]; + public byte Filler { get; set; } = 0x00; + public ushort CapabilityFlagsLower { get; set; } = 0xF7FF; // Default capabilities + public byte CharacterSet { get; set; } = 0x21; // utf8_general_ci + public ushort StatusFlags { get; set; } = 0x0002; // SERVER_STATUS_AUTOCOMMIT + public ushort CapabilityFlagsUpper { get; set; } = 0x0000; + public byte AuthPluginDataLength { get; set; } = 21; + public byte[] Reserved { get; set; } = new byte[10]; + public byte[] AuthPluginDataPart2 { get; set; } = new byte[12]; + public string AuthPluginName { get; set; } = "mysql_native_password"; + + public MySQLHandshakePacket() + { + // Generate random salt for authentication + GenerateAuthPluginData(); + } + + private void GenerateAuthPluginData() + { + using (var rng = RandomNumberGenerator.Create()) + { + rng.GetBytes(AuthPluginDataPart1); + rng.GetBytes(AuthPluginDataPart2); + + // Ensure no null bytes in the salt + for (int i = 0; i < AuthPluginDataPart1.Length; i++) + { + if (AuthPluginDataPart1[i] == 0) + AuthPluginDataPart1[i] = 1; + } + + for (int i = 0; i < AuthPluginDataPart2.Length; i++) + { + if (AuthPluginDataPart2[i] == 0) + AuthPluginDataPart2[i] = 1; + } + } + } + + public byte[] GetFullSalt() + { + var salt = new byte[20]; + Array.Copy(AuthPluginDataPart1, 0, salt, 0, 8); + Array.Copy(AuthPluginDataPart2, 0, salt, 8, 12); + return salt; + } + + public byte[] ToBytes() + { + var serverVersionBytes = Encoding.UTF8.GetBytes(ServerVersion); + var authPluginNameBytes = Encoding.UTF8.GetBytes(AuthPluginName); + + var packetLength = 1 + serverVersionBytes.Length + 1 + 4 + 8 + 1 + 2 + 1 + 2 + 2 + 1 + 10 + 12 + 1 + authPluginNameBytes.Length + 1; + var packet = new byte[4 + packetLength]; // 4 bytes for packet header + + int offset = 0; + + // Packet header + packet[offset++] = (byte)(packetLength & 0xFF); + packet[offset++] = (byte)((packetLength >> 8) & 0xFF); + packet[offset++] = (byte)((packetLength >> 16) & 0xFF); + packet[offset++] = 0x00; // Packet sequence ID + + // Protocol version + packet[offset++] = ProtocolVersion; + + // Server version + Array.Copy(serverVersionBytes, 0, packet, offset, serverVersionBytes.Length); + offset += serverVersionBytes.Length; + packet[offset++] = 0x00; // Null terminator + + // Connection ID + packet[offset++] = (byte)(ConnectionId & 0xFF); + packet[offset++] = (byte)((ConnectionId >> 8) & 0xFF); + packet[offset++] = (byte)((ConnectionId >> 16) & 0xFF); + packet[offset++] = (byte)((ConnectionId >> 24) & 0xFF); + + // Auth plugin data part 1 + Array.Copy(AuthPluginDataPart1, 0, packet, offset, 8); + offset += 8; + + // Filler + packet[offset++] = Filler; + + // Capability flags lower + packet[offset++] = (byte)(CapabilityFlagsLower & 0xFF); + packet[offset++] = (byte)((CapabilityFlagsLower >> 8) & 0xFF); + + // Character set + packet[offset++] = CharacterSet; + + // Status flags + packet[offset++] = (byte)(StatusFlags & 0xFF); + packet[offset++] = (byte)((StatusFlags >> 8) & 0xFF); + + // Capability flags upper + packet[offset++] = (byte)(CapabilityFlagsUpper & 0xFF); + packet[offset++] = (byte)((CapabilityFlagsUpper >> 8) & 0xFF); + + // Auth plugin data length + packet[offset++] = AuthPluginDataLength; + + // Reserved + Array.Copy(Reserved, 0, packet, offset, 10); + offset += 10; + + // Auth plugin data part 2 + Array.Copy(AuthPluginDataPart2, 0, packet, offset, 12); + offset += 12; + packet[offset++] = 0x00; // Null terminator for auth plugin data + + // Auth plugin name + Array.Copy(authPluginNameBytes, 0, packet, offset, authPluginNameBytes.Length); + offset += authPluginNameBytes.Length; + packet[offset++] = 0x00; // Null terminator + + return packet; + } + } +} \ No newline at end of file diff --git a/src/SuperSocket.MySQL/Authentication/MySQLHandshakeResponseFilter.cs b/src/SuperSocket.MySQL/Authentication/MySQLHandshakeResponseFilter.cs new file mode 100644 index 0000000..8ac8fb2 --- /dev/null +++ b/src/SuperSocket.MySQL/Authentication/MySQLHandshakeResponseFilter.cs @@ -0,0 +1,63 @@ +using System; +using System.Buffers; +using SuperSocket.ProtoBase; + +namespace SuperSocket.MySQL.Authentication +{ + /// + /// Filter for parsing MySQL handshake response packets during authentication. + /// + public class MySQLHandshakeResponseFilter : PackagePartsPipelineFilter + { + protected override MySQLHandshakeResponsePacket CreatePackage() + { + return new MySQLHandshakeResponsePacket(); + } + + protected override IPackagePartReader GetFirstPartReader() + { + return MySQLHandshakeResponsePartReader.PackageHeadReader; + } + } + + /// + /// Part reader for MySQL handshake response packets. + /// + public class MySQLHandshakeResponsePartReader : IPackagePartReader + { + public static IPackagePartReader PackageHeadReader { get; private set; } + + static MySQLHandshakeResponsePartReader() + { + PackageHeadReader = new MySQLHandshakeResponsePartReader(); + } + + public bool Process(MySQLHandshakeResponsePacket package, ref SequenceReader reader, out IPackagePartReader nextPartReader, out bool needMoreData) + { + nextPartReader = null; + needMoreData = false; + + try + { + var parsedPacket = MySQLHandshakeResponsePacket.ParseFromSequenceReader(ref reader); + + // Copy parsed data to the package + package.CapabilityFlags = parsedPacket.CapabilityFlags; + package.MaxPacketSize = parsedPacket.MaxPacketSize; + package.CharacterSet = parsedPacket.CharacterSet; + package.Reserved = parsedPacket.Reserved; + package.Username = parsedPacket.Username; + package.AuthResponse = parsedPacket.AuthResponse; + package.Database = parsedPacket.Database; + package.AuthPluginName = parsedPacket.AuthPluginName; + + return true; + } + catch + { + needMoreData = true; + return false; + } + } + } +} \ No newline at end of file diff --git a/src/SuperSocket.MySQL/Authentication/MySQLHandshakeResponsePacket.cs b/src/SuperSocket.MySQL/Authentication/MySQLHandshakeResponsePacket.cs new file mode 100644 index 0000000..59f418b --- /dev/null +++ b/src/SuperSocket.MySQL/Authentication/MySQLHandshakeResponsePacket.cs @@ -0,0 +1,170 @@ +using System; +using System.Buffers; +using System.Text; + +namespace SuperSocket.MySQL.Authentication +{ + /// + /// Represents the handshake response packet from the client. + /// Based on MySQL Protocol specification: + /// https://dev.mysql.com/doc/dev/mysql-server/8.0.11/page_protocol_connection_phase_packets_protocol_handshake_response.html + /// + public class MySQLHandshakeResponsePacket + { + public uint CapabilityFlags { get; set; } + public uint MaxPacketSize { get; set; } + public byte CharacterSet { get; set; } + public byte[] Reserved { get; set; } = new byte[23]; + public string Username { get; set; } = string.Empty; + public byte[] AuthResponse { get; set; } = new byte[0]; + public string Database { get; set; } = string.Empty; + public string AuthPluginName { get; set; } = "mysql_native_password"; + + public static MySQLHandshakeResponsePacket ParseFromBytes(byte[] data, int offset, int length) + { + var packet = new MySQLHandshakeResponsePacket(); + int pos = offset + 4; // Skip packet header + + // Capability flags (4 bytes) + packet.CapabilityFlags = BitConverter.ToUInt32(data, pos); + pos += 4; + + // Max packet size (4 bytes) + packet.MaxPacketSize = BitConverter.ToUInt32(data, pos); + pos += 4; + + // Character set (1 byte) + packet.CharacterSet = data[pos]; + pos += 1; + + // Reserved (23 bytes) + Array.Copy(data, pos, packet.Reserved, 0, 23); + pos += 23; + + // Username (null-terminated string) + int usernameStart = pos; + while (pos < data.Length && data[pos] != 0) + pos++; + packet.Username = Encoding.UTF8.GetString(data, usernameStart, pos - usernameStart); + pos++; // Skip null terminator + + // Auth response length + data + if (pos < data.Length) + { + byte authResponseLength = data[pos]; + pos++; + + if (authResponseLength > 0 && pos + authResponseLength <= data.Length) + { + packet.AuthResponse = new byte[authResponseLength]; + Array.Copy(data, pos, packet.AuthResponse, 0, authResponseLength); + pos += authResponseLength; + } + } + + // Database name (null-terminated string) - optional + if (pos < data.Length) + { + int databaseStart = pos; + while (pos < data.Length && data[pos] != 0) + pos++; + if (pos > databaseStart) + packet.Database = Encoding.UTF8.GetString(data, databaseStart, pos - databaseStart); + pos++; // Skip null terminator + } + + // Auth plugin name (null-terminated string) - optional + if (pos < data.Length) + { + int pluginStart = pos; + while (pos < data.Length && data[pos] != 0) + pos++; + if (pos > pluginStart) + packet.AuthPluginName = Encoding.UTF8.GetString(data, pluginStart, pos - pluginStart); + } + + return packet; + } + + public static MySQLHandshakeResponsePacket ParseFromSequenceReader(ref SequenceReader reader) + { + var packet = new MySQLHandshakeResponsePacket(); + + // Skip packet header (already handled by pipeline) + + // Capability flags (4 bytes) + if (reader.Length < 4) + throw new InvalidOperationException("Cannot read capability flags"); + + var capabilityFlags = 0u; + for (int i = 0; i < 4; i++) + { + if (!reader.TryRead(out byte b)) + throw new InvalidOperationException("Cannot read capability flags"); + capabilityFlags |= (uint)(b << (i * 8)); + } + packet.CapabilityFlags = capabilityFlags; + + // Max packet size (4 bytes) + if (reader.Length < 4) + throw new InvalidOperationException("Cannot read max packet size"); + + var maxPacketSize = 0u; + for (int i = 0; i < 4; i++) + { + if (!reader.TryRead(out byte b)) + throw new InvalidOperationException("Cannot read max packet size"); + maxPacketSize |= (uint)(b << (i * 8)); + } + packet.MaxPacketSize = maxPacketSize; + + // Character set (1 byte) + if (!reader.TryRead(out byte characterSet)) + throw new InvalidOperationException("Cannot read character set"); + packet.CharacterSet = characterSet; + + // Reserved (23 bytes) + packet.Reserved = new byte[23]; + for (int i = 0; i < 23; i++) + { + if (!reader.TryRead(out byte reservedByte)) + throw new InvalidOperationException("Cannot read reserved bytes"); + packet.Reserved[i] = reservedByte; + } + + // Username (null-terminated string) + if (!reader.TryReadTo(out ReadOnlySequence usernameSequence, 0x00, false)) + throw new InvalidOperationException("Cannot read username"); + packet.Username = Encoding.UTF8.GetString(usernameSequence); + reader.Advance(1); // Skip null terminator + + // Auth response length + data + if (reader.TryRead(out byte authResponseLength) && authResponseLength > 0) + { + packet.AuthResponse = new byte[authResponseLength]; + for (int i = 0; i < authResponseLength; i++) + { + if (!reader.TryRead(out byte authByte)) + throw new InvalidOperationException("Cannot read auth response"); + packet.AuthResponse[i] = authByte; + } + } + + // Database name (null-terminated string) - optional + if (reader.TryReadTo(out ReadOnlySequence databaseSequence, 0x00, false)) + { + packet.Database = Encoding.UTF8.GetString(databaseSequence); + reader.Advance(1); // Skip null terminator + } + + // Auth plugin name (null-terminated string) - optional + if (reader.TryReadTo(out ReadOnlySequence pluginSequence, 0x00, false)) + { + packet.AuthPluginName = Encoding.UTF8.GetString(pluginSequence); + reader.Advance(1); // Skip null terminator + } + + return packet; + } + } +} \ No newline at end of file diff --git a/src/SuperSocket.MySQL/Authentication/MySQLSession.cs b/src/SuperSocket.MySQL/Authentication/MySQLSession.cs new file mode 100644 index 0000000..a1443e6 --- /dev/null +++ b/src/SuperSocket.MySQL/Authentication/MySQLSession.cs @@ -0,0 +1,113 @@ +using System; +using System.Threading.Tasks; +using SuperSocket.Server; +using SuperSocket.Channel; +using Microsoft.Extensions.Logging; + +namespace SuperSocket.MySQL.Authentication +{ + /// + /// MySQL session that handles authentication before allowing query processing. + /// + public class MySQLSession : AppSession + { + private readonly MySQLAuthenticationHandler _authHandler; + private MySQLHandshakePacket _handshakePacket; + private bool _isAuthenticated = false; + + public MySQLSession() + { + _authHandler = new MySQLAuthenticationHandler(); + } + + protected override async ValueTask OnSessionConnectedAsync() + { + await base.OnSessionConnectedAsync(); + + // Send handshake packet immediately upon connection + _handshakePacket = _authHandler.CreateHandshake(); + var handshakeBytes = _handshakePacket.ToBytes(); + + await SendAsync(handshakeBytes); + + Logger?.LogInformation($"Sent handshake to connection {_handshakePacket.ConnectionId}"); + } + + public async Task HandleAuthenticationAsync(MySQLHandshakeResponsePacket response) + { + try + { + var salt = _handshakePacket?.GetFullSalt(); + if (salt == null) + { + Logger?.LogWarning("No handshake salt available for authentication"); + await SendErrorAsync(1045, "Access denied"); + return false; + } + + var isValid = _authHandler.ValidateCredentials(response, salt); + + if (isValid) + { + _isAuthenticated = true; + var okPacket = _authHandler.CreateOkPacket(); + await SendAsync(okPacket); + + Logger?.LogInformation($"User '{response.Username}' authenticated successfully"); + return true; + } + else + { + await SendErrorAsync(1045, "Access denied for user '" + response.Username + "'"); + Logger?.LogWarning($"Authentication failed for user '{response.Username}'"); + return false; + } + } + catch (Exception ex) + { + Logger?.LogError(ex, "Error during authentication"); + await SendErrorAsync(1045, "Authentication error"); + return false; + } + } + + public bool IsAuthenticated => _isAuthenticated; + + private async Task SendErrorAsync(ushort errorCode, string message) + { + var errorPacket = _authHandler.CreateErrorPacket(errorCode, message); + await SendAsync(errorPacket); + } + + private async Task SendAsync(byte[] data) + { + await Channel.SendAsync(new ReadOnlyMemory(data)); + } + + protected virtual async ValueTask OnPackageReceived(MySQLHandshakeResponsePacket package) + { + if (!IsAuthenticated) + { + // This is the authentication response + var success = await HandleAuthenticationAsync(package); + if (!success) + { + // Authentication failed, close the connection + await CloseAsync(); + } + } + else + { + // This should be a command packet after authentication + // In a real implementation, you'd switch filters after authentication + Logger?.LogWarning("Received data after authentication - command processing not implemented"); + } + } + + protected override async ValueTask OnSessionClosedAsync(EventArgs e) + { + Logger?.LogInformation($"MySQL session closed"); + await base.OnSessionClosedAsync(e); + } + } +} \ No newline at end of file diff --git a/src/SuperSocket.MySQL/Integration/MySQLIntegration.cs b/src/SuperSocket.MySQL/Integration/MySQLIntegration.cs new file mode 100644 index 0000000..956153c --- /dev/null +++ b/src/SuperSocket.MySQL/Integration/MySQLIntegration.cs @@ -0,0 +1,127 @@ +using System; +using System.Threading.Tasks; +using SuperSocket.MySQL.Authentication; + +namespace SuperSocket.MySQL.Integration +{ + /// + /// Simple integration example showing how to integrate MySQL authentication + /// with existing SuperSocket.MySQL QueryResultFilter. + /// + public class IntegratedMySQLSession : MySQLSession + { + private bool _switchedToQueryFilter = false; + + protected override async ValueTask OnPackageReceived(MySQLHandshakeResponsePacket package) + { + if (!IsAuthenticated) + { + // Handle authentication + var success = await HandleAuthenticationAsync(package); + if (!success) + { + await CloseAsync(); + return; + } + + System.Console.WriteLine($"Authentication successful for user: {package.Username}"); + + // After successful authentication, in a complete implementation + // you would switch the packet filter to handle MySQL commands + // For now, we'll just mark that we're ready for queries + _switchedToQueryFilter = true; + + System.Console.WriteLine("Ready to process MySQL queries"); + } + else + { + // Post-authentication: this would be query processing + System.Console.WriteLine("Received query data (processing not implemented)"); + + // In a real implementation, you would: + // 1. Parse the MySQL command packet + // 2. Execute the SQL query + // 3. Return results using QueryResult and QueryResultFilter + + // For demonstration, just send an OK packet + var handler = new MySQLAuthenticationHandler(); + var okPacket = handler.CreateOkPacket(); + await Channel.SendAsync(new ReadOnlyMemory(okPacket)); + } + } + + /// + /// This method shows how you might integrate with the existing QueryResult system + /// after authentication is complete. + /// + private async Task ProcessMySQLQuery(byte[] queryData) + { + try + { + // Parse the query packet (COM_QUERY or others) + // Execute the query using your database backend + // Create a QueryResult with the results + + var queryResult = new QueryResult + { + ErrorCode = 0, + ErrorMessage = null, + Columns = new[] { "id", "name", "email" }, + Rows = new System.Collections.Generic.List + { + new[] { "1", "John Doe", "john@example.com" }, + new[] { "2", "Jane Smith", "jane@example.com" } + } + }; + + // Send the result back to the client + // Note: This would require implementing a QueryResult serializer + // that follows MySQL protocol format + System.Console.WriteLine($"Query result: {queryResult.Rows.Count} rows"); + } + catch (Exception ex) + { + System.Console.WriteLine($"Error processing MySQL query: {ex.Message}"); + + // Send error packet + var handler = new MySQLAuthenticationHandler(); + var errorPacket = handler.CreateErrorPacket(1064, "Query execution error"); + await Channel.SendAsync(new ReadOnlyMemory(errorPacket)); + } + } + } + + /// + /// Example showing how to create a complete MySQL server with authentication. + /// + public static class MySQLServerExample + { + public static void ShowIntegrationExample() + { + System.Console.WriteLine("=== MySQL Authentication Integration Example ===\n"); + + System.Console.WriteLine("1. Authentication Flow:"); + System.Console.WriteLine(" - Client connects to server"); + System.Console.WriteLine(" - Server sends handshake packet with challenge"); + System.Console.WriteLine(" - Client responds with credentials"); + System.Console.WriteLine(" - Server validates and sends OK/ERR"); + + System.Console.WriteLine("\n2. Query Processing (after authentication):"); + System.Console.WriteLine(" - Client sends COM_QUERY packets"); + System.Console.WriteLine(" - Server processes SQL and returns QueryResult"); + System.Console.WriteLine(" - Results formatted according to MySQL protocol"); + + System.Console.WriteLine("\n3. Integration Points:"); + System.Console.WriteLine(" - MySQLSession handles authentication"); + System.Console.WriteLine(" - QueryResultFilter handles query responses"); + System.Console.WriteLine(" - Custom session bridges authentication → queries"); + + System.Console.WriteLine("\n4. Example Usage:"); + System.Console.WriteLine(" mysql -h 127.0.0.1 -u test -p"); + System.Console.WriteLine(" Password: test"); + System.Console.WriteLine(" mysql> SELECT * FROM users;"); + + System.Console.WriteLine("\n✓ Integration example complete"); + } + } +} \ No newline at end of file diff --git a/src/SuperSocket.MySQL/Program.cs b/src/SuperSocket.MySQL/Program.cs new file mode 100644 index 0000000..21d6580 --- /dev/null +++ b/src/SuperSocket.MySQL/Program.cs @@ -0,0 +1,21 @@ +using SuperSocket.MySQL.Tests; +using SuperSocket.MySQL.Integration; + +class Program +{ + static void Main(string[] args) + { + System.Console.WriteLine("SuperSocket.MySQL Authentication Implementation\n"); + + // Run basic authentication tests + AuthenticationTest.RunAllTests(); + + System.Console.WriteLine("\n" + new string('=', 50) + "\n"); + + // Show integration example + MySQLServerExample.ShowIntegrationExample(); + + System.Console.WriteLine("\nPress any key to exit..."); + System.Console.ReadKey(); + } +} \ No newline at end of file diff --git a/src/SuperSocket.MySQL/SuperSocket.MySQL.csproj b/src/SuperSocket.MySQL/SuperSocket.MySQL.csproj index ccac898..c478ccb 100644 --- a/src/SuperSocket.MySQL/SuperSocket.MySQL.csproj +++ b/src/SuperSocket.MySQL/SuperSocket.MySQL.csproj @@ -1,8 +1,10 @@ - netstandard2.1 + net6.0 + Exe - + + diff --git a/src/SuperSocket.MySQL/Tests/AuthenticationTest.cs b/src/SuperSocket.MySQL/Tests/AuthenticationTest.cs new file mode 100644 index 0000000..f205436 --- /dev/null +++ b/src/SuperSocket.MySQL/Tests/AuthenticationTest.cs @@ -0,0 +1,95 @@ +using System; +using System.Text; +using SuperSocket.MySQL.Authentication; + +namespace SuperSocket.MySQL.Tests +{ + /// + /// Simple test to validate MySQL authentication components. + /// + public class AuthenticationTest + { + public static void TestHandshakePacket() + { + System.Console.WriteLine("Testing MySQL Handshake Packet..."); + + var handshake = new MySQLHandshakePacket(); + var handshakeBytes = handshake.ToBytes(); + + System.Console.WriteLine($"Handshake packet size: {handshakeBytes.Length} bytes"); + System.Console.WriteLine($"Connection ID: {handshake.ConnectionId}"); + System.Console.WriteLine($"Server Version: {handshake.ServerVersion}"); + System.Console.WriteLine($"Protocol Version: {handshake.ProtocolVersion}"); + System.Console.WriteLine($"Salt length: {handshake.GetFullSalt().Length}"); + + System.Console.WriteLine("✓ Handshake packet created successfully"); + } + + public static void TestPasswordScrambling() + { + System.Console.WriteLine("\nTesting MySQL Password Scrambling..."); + + var handler = new MySQLAuthenticationHandler(); + var handshake = handler.CreateHandshake(); + var salt = handshake.GetFullSalt(); + + // Test with valid credentials (test/test) + var response = new MySQLHandshakeResponsePacket + { + Username = "test", + AuthResponse = handler.ScramblePasswordForTest("test", salt) // We need to expose this for testing + }; + + System.Console.WriteLine($"Username: {response.Username}"); + System.Console.WriteLine($"Auth response length: {response.AuthResponse?.Length ?? 0}"); + + System.Console.WriteLine("✓ Password scrambling test completed"); + } + + public static void TestOkErrorPackets() + { + System.Console.WriteLine("\nTesting OK and Error Packets..."); + + var handler = new MySQLAuthenticationHandler(); + + var okPacket = handler.CreateOkPacket(); + System.Console.WriteLine($"OK packet size: {okPacket.Length} bytes"); + + var errorPacket = handler.CreateErrorPacket(1045, "Access denied"); + System.Console.WriteLine($"Error packet size: {errorPacket.Length} bytes"); + + System.Console.WriteLine("✓ OK and Error packets created successfully"); + } + + public static void RunAllTests() + { + System.Console.WriteLine("=== MySQL Authentication Component Tests ===\n"); + + try + { + TestHandshakePacket(); + TestPasswordScrambling(); + TestOkErrorPackets(); + + System.Console.WriteLine("\n✅ All tests passed!"); + } + catch (Exception ex) + { + System.Console.WriteLine($"\n❌ Test failed: {ex.Message}"); + System.Console.WriteLine(ex.StackTrace); + } + } + } +} + +// Extension to expose ScramblePassword for testing +namespace SuperSocket.MySQL.Authentication +{ + public partial class MySQLAuthenticationHandler + { + public byte[] ScramblePasswordForTest(string password, byte[] salt) + { + return ScramblePassword(password, salt); + } + } +} \ No newline at end of file