// CollabVM-Agent-Windows/CollabVMAgent/VirtIOSerial.cs
// 2023-12-10 15:21:13 -05:00
// 252 lines, 10 KiB, C#

using CollabVMAgent.Protocol;
using MsgPack.Serialization;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.ConstrainedExecution;
using System.Runtime.InteropServices;
using System.Security;
using System.Text;
using System.Threading;
namespace CollabVMAgent
{
/// <summary>
/// Length-prefixed message transport over the VirtIO serial port.
/// Every frame is a 4-byte length (produced/consumed with <see cref="BitConverter"/>, i.e. host
/// byte order) followed by that many payload bytes. A dedicated reader thread started by the
/// constructor reads incoming frames, answers each non-empty frame with an ACK message and then
/// raises <see cref="Data"/> with the payload (on the reader thread).
/// </summary>
public class VirtIOSerial : IDisposable
{
    const int READ_INTERVAL = 3000;
    const int COPY_BUFFER_SIZE = 4096;
    const int ERROR_INVALID_HANDLE = 6;
    // Not readonly: the SetupDi* imports take it by ref.
    static Guid GUID_VIOSERIAL_PORT = new Guid(0x6fde7521, 0x1b65, 0x48ae, 0xb6, 0x28, 0x80, 0xbe, 0x62, 0x1, 0x60, 0x26);
    static readonly IntPtr INVALID_HANDLE_VALUE = new IntPtr(-1);

    readonly MessagePackSerializer serializer = MessagePackSerializer.Get<ProtocolMessage>();
    // Guards `viohnd` replacement, `disposed`, and multi-part writes. The reader thread sends ACKs
    // while other threads may call WriteMsg, so a header and its body must go out together.
    readonly object sync = new object();
    readonly string DevicePath;
    IntPtr viohnd;
    readonly Thread readThread;
    readonly CancellationTokenSource readCancelSrc;
    readonly CancellationToken readCancel;
    bool disposed = false;

    /// <summary>Raised on the reader thread with the payload of each received non-empty frame.</summary>
    public event EventHandler<TypedEventArgs<byte[]>> Data;

    /// <summary>
    /// Locates the VirtIO serial port, opens it and starts the reader thread.
    /// </summary>
    /// <exception cref="IOException">The port's device interface could not be enumerated.</exception>
    public VirtIOSerial()
    {
        DevicePath = GetDevicePath();
        // A failed open is tolerated, as before: ReadFile on INVALID_HANDLE_VALUE is expected to fail
        // with ERROR_INVALID_HANDLE, which makes the reader thread reopen the port.
        viohnd = OpenPort(DevicePath);
        readCancelSrc = new CancellationTokenSource();
        readCancel = readCancelSrc.Token;
        readThread = new Thread(new ThreadStart(ReadLoop));
        readThread.Start();
    }

    /// <summary>Writes raw bytes to the port.</summary>
    /// <param name="data">Bytes to write.</param>
    /// <returns><c>true</c> only if WriteFile succeeded and reported every byte as written.</returns>
    /// <exception cref="ObjectDisposedException">The instance has been disposed.</exception>
    /// <exception cref="ArgumentNullException"><paramref name="data"/> is null.</exception>
    public bool Write(byte[] data)
    {
        lock (sync)
        {
            if (disposed) throw new ObjectDisposedException(nameof(VirtIOSerial));
            if (data == null) throw new ArgumentNullException(nameof(data));
            return WriteAll(data);
        }
    }

    /// <summary>Serializes <paramref name="msg"/> with MessagePack and sends it as one length-prefixed frame.</summary>
    /// <returns><c>true</c> if both the length prefix and the body were fully written.</returns>
    /// <exception cref="ObjectDisposedException">The instance has been disposed.</exception>
    public bool WriteMsg(ProtocolMessage msg)
    {
        return WriteFramed(msg, throwIfDisposed: true);
    }

    // Serializes and writes one frame atomically with respect to other writers. When disposed,
    // either throws (public callers) or returns false (the reader thread, which must not crash).
    bool WriteFramed(ProtocolMessage msg, bool throwIfDisposed)
    {
        byte[] message;
        using (var ms = new MemoryStream())
        {
            serializer.Pack(ms, msg);
            message = ms.ToArray();
        }
        byte[] header = BitConverter.GetBytes((UInt32)message.Length);
        lock (sync)
        {
            if (disposed)
            {
                if (throwIfDisposed) throw new ObjectDisposedException(nameof(VirtIOSerial));
                return false;
            }
            return WriteAll(header) && WriteAll(message);
        }
    }

    // Caller must hold `sync`. Succeeds only when WriteFile reports all bytes written.
    bool WriteAll(byte[] data)
    {
        uint wrote;
        return WriteFile(viohnd, data, (uint)data.Length, out wrote, IntPtr.Zero)
            && wrote == (uint)data.Length;
    }

    // Reader thread body: header -> payload -> ACK -> Data event, until cancelled.
    void ReadLoop()
    {
        byte[] sizebuf = new byte[4];
        byte[] chunk = new byte[COPY_BUFFER_SIZE];
        while (!readCancel.IsCancellationRequested)
        {
            if (!ReadHeader(sizebuf, chunk)) return;
            uint size = BitConverter.ToUInt32(sizebuf, 0);
#if DEBUG
            Console.WriteLine($"Reported payload size: {size}");
#endif
            if (size == 0)
            {
                continue;
            }
            byte[] payload = ReadPayload(size, chunk);
            if (payload == null) return;
            if (WriteFramed(new ProtocolMessage { Operation = ProtocolOperation.ACK }, throwIfDisposed: false))
            {
                Console.WriteLine("Wrote ACK to serial.");
            }
            // Do not deliver frames to subscribers once Dispose has been called.
            if (readCancel.IsCancellationRequested) return;
            Data?.Invoke(this, new TypedEventArgs<byte[]>(payload));
        }
    }

    // Fills `header` completely, sleeping READ_INTERVAL before every ReadFile attempt (the original
    // polling cadence). Short reads are accumulated rather than discarded. ERROR_INVALID_HANDLE
    // triggers a reopen; other errors are retried. Returns false if cancelled.
    bool ReadHeader(byte[] header, byte[] scratch)
    {
        int filled = 0;
        while (filled < header.Length)
        {
            if (readCancel.IsCancellationRequested) return false;
            Thread.Sleep(READ_INTERVAL);
            uint read;
            if (ReadFile(viohnd, scratch, (uint)(header.Length - filled), out read, IntPtr.Zero))
            {
                Buffer.BlockCopy(scratch, 0, header, filled, (int)read);
                filled += (int)read;
                continue;
            }
            // Last-error is only meaningful when ReadFile reports failure.
            int err = Marshal.GetLastWin32Error();
#if DEBUG
            Console.WriteLine($"Got error {err} while trying to read from serial port");
#endif
            if (err == ERROR_INVALID_HANDLE)
            {
                Reopen();
            }
        }
        return true;
    }

    // Reads exactly `size` payload bytes using `chunk` as scratch. Returns null if cancelled.
    // NOTE(review): the IOExceptions below are thrown on the reader thread and, as in the original,
    // are unhandled there, which terminates the process; confirm that is the intended policy.
    byte[] ReadPayload(uint size, byte[] chunk)
    {
        using (var ms = new MemoryStream())
        {
            uint position = 0;
            while (position < size)
            {
                if (readCancel.IsCancellationRequested) return null;
                // Never request more than the remainder of this frame, or the next frame's
                // length prefix would be swallowed into this payload.
                uint want = Math.Min(size - position, (uint)chunk.Length);
                uint read;
                if (!ReadFile(viohnd, chunk, want, out read, IntPtr.Zero))
                {
                    throw new IOException($"Got error {Marshal.GetLastWin32Error()} while reading payload from serial port");
                }
                if (read == 0) throw new IOException("Stream ended prematurely");
                ms.Write(chunk, 0, (int)read);
#if DEBUG
                Console.WriteLine($"Bytes read: {read}, Position: {position}, Size: {size}");
#endif
                position += read;
            }
            return ms.ToArray();
        }
    }

    // Replaces the port handle with a freshly opened one, unless the instance is disposed
    // (reopening after Dispose would leak the new handle).
    void Reopen()
    {
        lock (sync)
        {
            if (disposed) return;
            if (viohnd != INVALID_HANDLE_VALUE) CloseHandle(viohnd);
            viohnd = OpenPort(DevicePath);
        }
    }

    // Opens the port for read/write, exclusive (share mode 0). Returns INVALID_HANDLE_VALUE on failure.
    static IntPtr OpenPort(string path)
    {
        return CreateFile(path, FileAccess.ReadWrite, 0, IntPtr.Zero, FileMode.Open, FileAttributes.Normal, IntPtr.Zero);
    }

    /// <summary>
    /// Finds the device path of the first present interface of the VirtIO serial port class.
    /// </summary>
    /// <exception cref="IOException">Any SetupDi enumeration step failed.</exception>
    static string GetDevicePath()
    {
        IntPtr deviceInfoSet = SetupDiGetClassDevs(ref GUID_VIOSERIAL_PORT, IntPtr.Zero, IntPtr.Zero, (uint)(DIGCF.DIGCF_PRESENT | DIGCF.DIGCF_DEVICEINTERFACE));
        if (deviceInfoSet == INVALID_HANDLE_VALUE)
        {
            throw new IOException("Cannot get class devices.");
        }
        IntPtr detailData = IntPtr.Zero;
        try
        {
            SP_DEVICE_INTERFACE_DATA interfaceData = new SP_DEVICE_INTERFACE_DATA();
            interfaceData.cbSize = Marshal.SizeOf(typeof(SP_DEVICE_INTERFACE_DATA));
            if (!SetupDiEnumDeviceInterfaces(deviceInfoSet, IntPtr.Zero, ref GUID_VIOSERIAL_PORT, 0, ref interfaceData))
            {
                throw new IOException("Cannot get enumerate device interfaces.");
            }

            // Sizing call: with a null buffer it fails by design and only reports the required size.
            uint requiredLength = 0;
            SetupDiGetDeviceInterfaceDetail(deviceInfoSet, ref interfaceData, IntPtr.Zero, 0, ref requiredLength, IntPtr.Zero);
            if (requiredLength == 0)
            {
                throw new IOException("Cannot get device interface details.");
            }

            // cbSize of the native SP_DEVICE_INTERFACE_DETAIL_DATA_W header (DWORD cbSize + one WCHAR):
            // 6 on 32-bit (byte-packed), 8 on 64-bit (padded). SystemDefaultCharSize is the WCHAR part.
            int structSize = Marshal.SystemDefaultCharSize + (IntPtr.Size == 8 ? 6 : 4);
            // AllocHGlobal throws OutOfMemoryException rather than returning IntPtr.Zero.
            detailData = Marshal.AllocHGlobal((int)requiredLength + structSize);
            Marshal.WriteInt32(detailData, structSize);
            if (!SetupDiGetDeviceInterfaceDetail(deviceInfoSet, ref interfaceData, detailData, requiredLength, ref requiredLength, IntPtr.Zero))
            {
                throw new IOException("Cannot get device interface details.");
            }
            // DevicePath starts right after the 4-byte cbSize field.
            return Marshal.PtrToStringUni(IntPtr.Add(detailData, 4));
        }
        finally
        {
            if (detailData != IntPtr.Zero) Marshal.FreeHGlobal(detailData);
            SetupDiDestroyDeviceInfoList(deviceInfoSet);
        }
    }

    [DllImport("setupapi.dll", CharSet = CharSet.Auto, SetLastError = true)]
    static extern IntPtr SetupDiGetClassDevs(
        ref Guid ClassGuid,
        IntPtr Enumerator,
        IntPtr hwndParent,
        uint Flags
    );
    [DllImport(@"setupapi.dll", CharSet = CharSet.Auto, SetLastError = true)]
    static extern Boolean SetupDiEnumDeviceInterfaces(
        IntPtr hDevInfo,
        IntPtr devInfo,
        ref Guid interfaceClassGuid,
        UInt32 memberIndex,
        ref SP_DEVICE_INTERFACE_DATA deviceInterfaceData
    );
    [DllImport(@"setupapi.dll", CharSet = CharSet.Auto, SetLastError = true)]
    static extern Boolean SetupDiGetDeviceInterfaceDetail(
        IntPtr hDevInfo,
        ref SP_DEVICE_INTERFACE_DATA deviceInterfaceData,
        IntPtr deviceInterfaceDetailData,
        UInt32 deviceInterfaceDetailDataSize,
        ref UInt32 requiredSize,
        IntPtr deviceInfoData
    );
    [DllImport("setupapi.dll", SetLastError = true)]
    static extern bool SetupDiDestroyDeviceInfoList(IntPtr DeviceInfoSet);
    [DllImport("kernel32.dll", CharSet = CharSet.Auto, SetLastError = true)]
    static extern IntPtr CreateFile(
        [MarshalAs(UnmanagedType.LPTStr)] string filename,
        [MarshalAs(UnmanagedType.U4)] FileAccess access,
        [MarshalAs(UnmanagedType.U4)] uint share,
        IntPtr securityAttributes, // optional SECURITY_ATTRIBUTES struct or IntPtr.Zero
        [MarshalAs(UnmanagedType.U4)] FileMode creationDisposition,
        [MarshalAs(UnmanagedType.U4)] FileAttributes flagsAndAttributes,
        IntPtr templateFile
    );
    [DllImport("kernel32.dll", SetLastError = true)]
    static extern bool ReadFile(IntPtr hFile, [Out] byte[] lpBuffer, uint nNumberOfBytesToRead, out uint lpNumberOfBytesRead, IntPtr lpOverlapped);
    [DllImport("kernel32.dll", SetLastError = true)]
    static extern bool WriteFile(IntPtr hFile, byte[] lpBuffer, uint nNumberOfBytesToWrite, out uint lpNumberOfBytesWritten, IntPtr lpOverlapped);
    [DllImport("kernel32.dll", SetLastError = true)]
    [ReliabilityContract(Consistency.WillNotCorruptState, Cer.Success)]
    [SuppressUnmanagedCodeSecurity]
    [return: MarshalAs(UnmanagedType.Bool)]
    static extern bool CloseHandle(IntPtr hObject);

    /// <summary>
    /// Signals the reader thread to stop (it exits at its next cancellation check) and closes the
    /// port handle. Safe to call more than once. Does not wait for the reader thread to finish.
    /// </summary>
    public void Dispose()
    {
        lock (sync)
        {
            if (disposed) return;
            disposed = true;
            readCancelSrc.Cancel();
            if (viohnd != INVALID_HANDLE_VALUE) CloseHandle(viohnd);
            viohnd = INVALID_HANDLE_VALUE;
        }
        readCancelSrc.Dispose();
    }
}
// Managed mirror of the Win32 SP_DEVICE_INTERFACE_DATA structure passed by ref to
// SetupDiEnumDeviceInterfaces / SetupDiGetDeviceInterfaceDetail. Sequential layout means the
// field order and field types below ARE the native layout: do not reorder, add or retype fields.
[StructLayout(LayoutKind.Sequential)]
struct SP_DEVICE_INTERFACE_DATA
{
// Size of this struct in bytes; GetDevicePath sets it to Marshal.SizeOf(SP_DEVICE_INTERFACE_DATA)
// before the struct is passed to SetupDiEnumDeviceInterfaces.
public Int32 cbSize;
// Interface class GUID, filled in by the SetupDi call.
public Guid interfaceClassGuid;
// Interface flags as reported by the SetupDi call.
public Int32 flags;
// Native ULONG_PTR reserved field; UIntPtr keeps it pointer-sized on both 32- and 64-bit.
private UIntPtr reserved;
}
/// <summary>
/// Flag bits for the <c>Flags</c> argument of <c>SetupDiGetClassDevs</c>; combine with bitwise OR.
/// Each member occupies its own bit, so the values are written as shifts.
/// </summary>
[Flags]
public enum DIGCF : uint
{
    /// <summary>Default interface only; only valid together with <see cref="DIGCF_DEVICEINTERFACE"/>.</summary>
    DIGCF_DEFAULT = 1u << 0,

    /// <summary>Restrict the result to devices that are currently present.</summary>
    DIGCF_PRESENT = 1u << 1,

    /// <summary>Return devices of all setup classes rather than only the given one.</summary>
    DIGCF_ALLCLASSES = 1u << 2,

    /// <summary>Restrict the result to devices in the current hardware profile.</summary>
    DIGCF_PROFILE = 1u << 3,

    /// <summary>Treat the supplied GUID as a device interface class GUID.</summary>
    DIGCF_DEVICEINTERFACE = 1u << 4,
}
}