Distinguishing completion callbacks (IoCompletionCallback)

noname 141 Reputation points
2021-10-31T15:36:56.87+00:00

Hi all,

I am writing a generic proxy (protocol agnostic) in C++ using the thread pool functions (https://learn.microsoft.com/en-us/windows/win32/procthread/using-the-thread-pool-functions).

The proxy might receive data from both the client and the server, possibly at the same time:

  • If the proxy receives data from the client, it has to forward it to the server.
  • If the proxy receives data from the server, it has to forward it to the client.

As the proxy doesn't know the protocol used, it just starts asynchronous reads on both sockets. When data is received from one socket, it sends it to the other socket.

Each socket has its own callback function (IoCompletionCallback) and its own TP_IO. I am using CreateThreadpoolIo() (https://learn.microsoft.com/en-us/windows/win32/api/threadpoolapiset/nf-threadpoolapiset-createthreadpoolio) and StartThreadpoolIo() (https://learn.microsoft.com/en-us/windows/win32/api/threadpoolapiset/nf-threadpoolapiset-startthreadpoolio).

My question is: when the completion callback (IoCompletionCallback) of one of the sockets is invoked, how can I know which operation triggered it? It could be the completion of a receive [WSARecv()] or of a send [WSASend()].

I thought I could use two different completion callbacks and two TP_IO objects per socket. Would this work, or can I have only one TP_IO per socket (handle)?

Is there an easier approach for distinguishing the operation which triggered the completion callback?

Thank you in advance.

Windows API - Win32
C++

Accepted answer
  1. noname 141 Reputation points
    2021-11-03T07:02:43.097+00:00

    The solution is to use an extended OVERLAPPED structure that also stores the operation type (send or receive).
    Additionally, an OVERLAPPED structure cannot be reused while another operation that uses it is still in progress, so each outstanding operation needs its own structure.
    I have taken the idea from ctThreadIocp.hpp
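
    A minimal sketch of the extended-OVERLAPPED idea, not the code from ctThreadIocp.hpp. The names io_operation, io_context, begin_operation() and complete_operation() are hypothetical and only for illustration. The sketch assumes one heap-allocated context per outstanding operation, which also avoids sharing an OVERLAPPED between a pending receive and a pending send. Off Windows, a stand-in OVERLAPPED is defined so the snippet compiles anywhere.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <type_traits>

#ifdef _WIN32
#include <winsock2.h>
#include <windows.h>
#else
// Stand-in (assumption) so the sketch builds off Windows; on Windows the
// real OVERLAPPED from <windows.h> is used instead.
struct OVERLAPPED { void* Internal; void* InternalHigh; void* Pointer; void* hEvent; };
#endif

// Hypothetical tag identifying which kind of operation a context belongs to.
enum class io_operation { receive, send };

// Extended OVERLAPPED: the OVERLAPPED must be the first member so that a
// pointer to it is also a pointer to the enclosing io_context.
struct io_context {
  OVERLAPPED overlapped{};   // Zero-initialized, as the I/O functions expect.
  io_operation operation;

  explicit io_context(io_operation op) : operation(op) {}
};

static_assert(std::is_standard_layout<io_context>::value,
              "io_context must be standard-layout for the cast below");
static_assert(offsetof(io_context, overlapped) == 0,
              "OVERLAPPED must be the first member");

// Called before starting an operation: allocates a fresh context and returns
// the OVERLAPPED* to pass to WSARecv()/WSASend(). Ownership stays with the
// pending operation until its completion is delivered.
inline OVERLAPPED* begin_operation(io_operation op) {
  return &(new io_context(op))->overlapped;
}

// Called from the completion callback with its `overlapped` argument:
// recovers the context and takes ownership back, so it is freed on scope exit.
inline std::unique_ptr<io_context> complete_operation(void* overlapped) {
  OVERLAPPED* const ov = static_cast<OVERLAPPED*>(overlapped);
  return std::unique_ptr<io_context>(reinterpret_cast<io_context*>(ov));
}
```

    Inside the completion callback, `complete_operation(overlapped)->operation` then says whether the completion belongs to a receive or a send. If the call that starts the operation fails immediately with an error other than WSA_IO_PENDING, no completion is delivered, so the context has to be freed at that point (and CancelThreadpoolIo() called for the matching StartThreadpoolIo()).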


1 additional answer

  1. noname 141 Reputation points
    2021-11-01T06:46:18.153+00:00

    (I am posting this as an answer because it is too long for a comment.)

    I am passing a pointer to an instance of my socket class in CreateThreadpoolIo(). I am reusing the same PTP_IO for all the socket operations on the same socket.

    See an extract of my socket class below. Each socket instance has its own PTP_IO (the _M_io member). In socket::init() I call CreateThreadpoolIo(), passing the this pointer as the context, and in the static method socket::io_completion_callback() I cast the context back to a socket*.

    If I start two asynchronous operations on the same socket object (for example, one WSARecv() and one WSASend()), I cannot distinguish in socket::io_completion_callback() which completion it is.

    That's why I thought that one possibility would be to have two PTP_IO per socket object (one for all the WSARecv() operations and one for all the WSASend() operations).

    My socket class looks as follows:

    ///////////////////////
    // socket.hpp        //
    ///////////////////////
    class socket {
      public:
        ...
        // Initialize socket.
        bool init(int domain,
                  PTP_WIN32_IO_CALLBACK callback,
                  PTP_CALLBACK_ENVIRON callbackenv = nullptr);
    
        ...
    
        // Receive.
        DWORD receive(void* buf, size_t len, DWORD flags, DWORD& received);
    
        // Send.
        DWORD send(const void* buf, size_t len, DWORD flags, DWORD& sent);
    
      private:
        // Socket handle.
        SOCKET _M_sock = INVALID_SOCKET;
    
        // Overlapped structure.
        OVERLAPPED _M_overlapped;
    
        // I/O completion object.
        // Here I use the same PTP_IO for all the operations on the same socket.
        PTP_IO _M_io = nullptr;
    
        // I/O completion callback.
        static void CALLBACK io_completion_callback(PTP_CALLBACK_INSTANCE instance,
                                                    void* context,
                                                    void* overlapped,
                                                    ULONG result,
                                                    ULONG_PTR transferred,
                                                    PTP_IO io);
    
        ...
    };
    
    ///////////////////////
    // socket.cpp        //
    ///////////////////////
    bool socket::init(int domain,
                      PTP_WIN32_IO_CALLBACK callback,
                      PTP_CALLBACK_ENVIRON callbackenv)
    {
      // Create an overlapped socket (WSA_FLAG_OVERLAPPED).
      _M_sock = ::WSASocket(domain,
                            SOCK_STREAM,
                            0,
                            nullptr,
                            0,
                            WSA_FLAG_OVERLAPPED);
    
      // If the socket could be created...
      if (_M_sock != INVALID_SOCKET) {
        // Create I/O completion object.
        _M_io = ::CreateThreadpoolIo(reinterpret_cast<HANDLE>(_M_sock),
                                     io_completion_callback,
                                     this,
                                     callbackenv);
    
        // If the I/O completion object could be created...
        if (_M_io) {
          // Clear overlapped structure.
          memset(&_M_overlapped, 0, sizeof(OVERLAPPED));
    
          return true;
        }
      }
    
      return false;
    }
    
    void CALLBACK socket::io_completion_callback(PTP_CALLBACK_INSTANCE instance,
                                                 void* context,
                                                 void* overlapped,
                                                 ULONG result,
                                                 ULONG_PTR transferred,
                                                 PTP_IO io)
    {
      socket* const sock = static_cast<socket*>(context);
    
      // Here I check `result` (error code) and `transferred`.
    }
    