1 /*
2  * Geario - A cross-platform abstraction library with asynchronous I/O.
3  *
4  * Copyright (C) 2021-2022 Kerisy.com
5  *
6  * Website: https://www.kerisy.com
7  *
8  * Licensed under the Apache-2.0 License.
9  *
10  */
11 
12 module geario.event.selector.IOCP;
13 
14 // dfmt off
15 version (HAVE_IOCP) : 
16 // dfmt on
17 
18 import geario.event.selector.Selector;
19 import geario.net.channel.Types;
20 import geario.net.channel;
21 import geario.event.timer;
22 import geario.logging;
23 import geario.system.Error;
24 import geario.net.channel.iocp.AbstractStream;
25 import core.sys.windows.windows;
26 import std.conv;
27 import std.socket;
28 import geario.util.worker;
29 import std.container : DList;
30 
31 
/**
 * I/O-completion-port based selector (Windows).
 *
 * TCP, Accept and UDP channels are associated with a single completion port,
 * using the channel object itself as the completion key. DoSelect dequeues
 * completion packets from that port and dispatches them to the channel stored
 * in the packet's IocpContext. Timer channels are not associated with the
 * port; they are added to a CustomTimer wheel instead.
 */
class AbstractSelector : Selector {

    /**
     * Creates the completion port and the stop event, and initializes the timer.
     *
     * Failures to create the port or the event are logged, not thrown.
     *
     * Params:
     *   number      = forwarded to Selector.
     *   divider     = forwarded to Selector.
     *   worker      = optional worker, forwarded to Selector.
     *   maxChannels = forwarded to Selector.
     */
    this(size_t number, size_t divider, Worker worker = null, size_t maxChannels = 1500) {
        super(number, divider, worker, maxChannels);
        _iocpHandle = CreateIoCompletionPort(INVALID_HANDLE_VALUE, null, 0, 0);
        if (_iocpHandle is null)
            log.error("CreateIoCompletionPort failed: %d\n", GetLastError());
        _timer.init();
        // Manual-reset event, initially non-signaled; DoSelect polls it on every iteration.
        _stopEvent = CreateEvent(NULL, TRUE, FALSE, NULL);
        if (_stopEvent is null)
            log.error("CreateEvent failed: %d", GetLastError());
    }

    /// Releases the stop-event and completion-port handles owned by this selector.
    ~this() {
        // Only CloseHandle is called here. Destructors may run inside a GC
        // collection, where allocating (e.g. building log messages) is not allowed.
        if (_stopEvent !is null) {
            CloseHandle(_stopEvent);
            _stopEvent = null;
        }
        if (_iocpHandle !is null) {
            CloseHandle(_iocpHandle);
            _iocpHandle = null;
        }
    }

    /**
     * Registers a channel with this selector.
     *
     * Timer channels are armed and added to the timer wheel. TCP, Accept and
     * UDP channels are associated with the completion port, keyed by the
     * channel object. Any channel that is an AbstractStream then starts its
     * first overlapped read.
     *
     * Returns: false if a timer channel cannot be armed; true otherwise.
     */
    override bool Register(AbstractChannel channel) {
        super.Register(channel);

        ChannelType ct = channel.Type;
        auto fd = channel.handle;
        version (GEAR_IO_DEBUG)
            log.trace("register, channel(fd=%d, type=%s)", fd, ct);

        if (ct == ChannelType.Timer) {
            AbstractTimer timerChannel = cast(AbstractTimer) channel;
            assert(timerChannel !is null);
            if (!timerChannel.setTimerOut())
                return false;
            _timer.timeWheel().addNewTimer(timerChannel.timer, timerChannel.wheelSize());
        } else if (ct == ChannelType.TCP
                || ct == ChannelType.Accept || ct == ChannelType.UDP) {
            version (GEAR_IO_DEBUG)
                log.trace("Run CreateIoCompletionPort on socket: %d", fd);

            HANDLE port = CreateIoCompletionPort(cast(HANDLE) fd, _iocpHandle,
                    cast(size_t)(cast(void*) channel), 0);
            // Without this association, completions for the channel's overlapped
            // I/O never reach DoSelect. Registration still proceeds, but the
            // failure is logged so it does not go unnoticed.
            if (port is null)
                log.error("CreateIoCompletionPort failed for fd=%d: %d", fd, GetLastError());
        } else {
            log.warn("Can't register a channel: %s", ct);
        }

        auto stream = cast(AbstractStream) channel;
        if (stream !is null) {
            stream.BeginRead();
        }

        return true;
    }

    /**
     * Deregisters a channel. This only delegates to Selector.Deregister.
     *
     * The handle stays associated with the completion port. Win32 offers no
     * API to dissociate a handle from a port; see the link below.
     */
    override bool Deregister(AbstractChannel channel) {
        // FIXME: Needing refactor or cleanup -@Administrator at 8/28/2018, 3:28:18 PM
        // https://stackoverflow.com/questions/6573218/removing-a-handle-from-a-i-o-completion-port-and-other-questions-about-iocp
        version (GEAR_IO_DEBUG)
            log.trace("deregister (fd=%d)", channel.handle);

        return super.Deregister(channel);
    }

    /// Re-initializes the timer, then runs the base selector loop.
    override void OnLoop(long timeout = -1) {
        _timer.init();
        super.OnLoop(timeout);
    }

    /**
     * Dequeues and dispatches completion packets until the selector stops.
     *
     * The loop ends when `_stopEvent` is signaled, when IsStopping() becomes
     * true, or when the completion port fails without delivering a packet.
     *
     * Params:
     *   t = unused. The wait on the completion port is INFINITE.
     * Returns: always 0.
     */
    protected override int DoSelect(long t) {
        // NOTE(review): the timer wheel is advanced only once per DoSelect call,
        // and the wait below is INFINITE, so timers are not serviced while this
        // loop is blocked. Confirm whether timers are driven elsewhere.
        _timer.doWheel();

        OVERLAPPED* overlapped;
        ULONG_PTR key = 0;
        DWORD bytes = 0;

        while (WAIT_OBJECT_0 != WaitForSingleObject(_stopEvent, 0) && !IsStopping()) {
            // https://docs.microsoft.com/en-us/windows/win32/api/ioapiset/nf-ioapiset-getqueuedcompletionstatus
            const int ret = GetQueuedCompletionStatus(_iocpHandle, &bytes, &key,
                    &overlapped, INFINITE);

            // NOTE(review): this assumes OVERLAPPED is the first member of
            // IocpContext, so both share an address. Otherwise a
            // CONTAINING_RECORD-style offset adjustment is needed.
            IocpContext* ev = cast(IocpContext*) overlapped;

            if (ret == 0) {
                const DWORD dwErr = GetLastError();
                if (dwErr == WAIT_TIMEOUT)
                    continue;

                if (ev is null) {
                    // No packet was dequeued, so the port itself failed (for
                    // example, its handle was closed). Hand control back to
                    // the caller rather than dereferencing a null context.
                    log.warn("GetQueuedCompletionStatus failed without a packet: %d", dwErr);
                    break;
                }

                // A packet was dequeued for a failed I/O operation: close its channel.
                AbstractChannel failed = ev.channel;
                if (failed !is null && !failed.IsClosed()) {
                    failed.Close();
                }
                continue;
            }

            if (ev is null || ev.channel is null) {
                // A packet with no context, e.g. the wake-up posted by Stop().
                version (GEAR_IO_DEBUG)
                    log.warn("The ev is null or ev.watche is null. isStopping: %s", IsStopping());
                continue;
            }

            if (0 == bytes && (ev.operation == IocpOperation.read || ev.operation == IocpOperation.write)) {
                // A read or write that transferred zero bytes closes the channel.
                AbstractChannel finished = ev.channel;
                if (!finished.IsClosed()) {
                    finished.Close();
                }
                continue;
            }

            HandleChannelEvent(ev.operation, ev.channel, bytes);
        }

        return 0;
    }

    /// Dispatches one successful completion to the channel, based on the operation type.
    private void HandleChannelEvent(IocpOperation op, AbstractChannel channel, DWORD bytes) {

        version (GEAR_IO_DEBUG)
            log.info("ev.operation: %s, fd=%d", op, channel.handle);

        switch (op) {
            case IocpOperation.accept:
                channel.OnRead();
                break;
            case IocpOperation.connect:
                OnSocketRead(channel, 0);
                // After a connect completes, start reading. Guard against a
                // non-stream channel instead of dereferencing a failed cast.
                if (auto stream = cast(AbstractStream) channel)
                    stream.BeginRead();
                else
                    log.warn("connect completed on a non-stream channel");
                break;
            case IocpOperation.read:
                OnSocketRead(channel, bytes);
                break;
            case IocpOperation.write:
                OnSocketWrite(channel, bytes);
                break;
            case IocpOperation.event:
                channel.OnRead();
                break;
            case IocpOperation.close:
                break;
            default:
                log.warn("unsupported operation type: %s", op);
                break;
        }
    }

    /**
     * Stops the selector and wakes DoSelect.
     *
     * Posts an empty packet (null OVERLAPPED) so that a blocked
     * GetQueuedCompletionStatus returns and the loop re-checks IsStopping().
     */
    override void Stop() {
        super.Stop();
        PostQueuedCompletionStatus(_iocpHandle, 0, 0, null);
    }

    /// Currently a no-op.
    void HandleTimer() {

    }

    /**
     * Hands `len` received bytes to a socket channel and notifies it.
     *
     * Does nothing for a zero-length completion or an already-closed channel.
     */
    private void OnSocketRead(AbstractChannel channel, size_t len) {
        if (channel is null) {
            log.warn("channel is null");
            return;
        }

        if (len == 0 || channel.IsClosed) {
            version (GEAR_IO_DEBUG)
                log.info("channel [fd=%d] closed. isClosed: %s, len: %d", channel.handle, channel.IsClosed, len);
            return;
        }

        AbstractSocketChannel socketChannel = cast(AbstractSocketChannel) channel;
        if (socketChannel is null) {
            log.warn("The channel socket is null: ");
        } else {
            socketChannel.setRead(len);
            channel.OnRead();
        }
    }

    /// Reports to a stream channel how many bytes a completed write actually sent.
    private void OnSocketWrite(AbstractChannel channel, size_t len) {
        if (channel is null) {
            log.warn("channel is null");
            return;
        }

        AbstractStream client = cast(AbstractStream) channel;
        if (client is null) {
            log.warn("The channel socket is null: ");
            return;
        }
        client.OnWriteDone(len);
    }


private:
    HANDLE _iocpHandle; // completion port shared by all socket channels
    CustomTimer _timer;
    HANDLE _stopEvent;  // manual-reset event polled by DoSelect
}