From 3bb03bc5d095cf51c8605521fcc2ce8cad36995f Mon Sep 17 00:00:00 2001 From: Fiona Klute Date: Tue, 25 Mar 2025 13:29:04 +0100 Subject: [PATCH] gi.events._Selector.get_map(): look up file objects by file descriptor File objects and their underlying file descriptors must be treated as equivalent for lookup, or one may incorrectly be treated as not registered when the other was used for the registration, leading to bugs. Signed-off-by: Fiona Klute Upstream: https://gitlab.gnome.org/GNOME/pygobject/-/commit/3bb03bc5d095cf51c8605521fcc2ce8cad36995f --- gi/events.py | 36 +++++++++++++++++++++----- tests/test_events.py | 61 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 7 deletions(-) diff --git a/gi/events.py b/gi/events.py index 9f00059ec..980cae0fe 100644 --- a/gi/events.py +++ b/gi/events.py @@ -28,6 +28,7 @@ import threading import selectors import weakref import warnings +from collections.abc import Mapping from contextlib import contextmanager from . import _ossighelper @@ -381,9 +382,33 @@ if sys.platform != 'win32': # Subclass to attach _tag pass + class _FileObjectMapping(Mapping): + def __init__(self, fd_dict): + self.fd_dict = fd_dict + + def __len__(self): + return len(self.fd_dict) + + def get(self, fileobj, default=None): + fd = _fileobj_to_fd(fileobj) + return self.fd_dict.get(fd, default) + + def __getitem__(self, fileobj): + value = self.get(fileobj) + if value is None: + raise KeyError("{!r} is not registered".format(fileobj)) + return value + + def __iter__(self): + return iter(self.fd_dict) + class _Selector(_SelectorMixin, selectors.BaseSelector): """A Selector for gi.events.GLibEventLoop registering python IO with GLib.""" + def __init__(self, context, loop): + super().__init__(context, loop) + self._map = _FileObjectMapping(self._fd_to_key) + def attach(self): self._source.attach(self._loop._context) @@ -432,15 +457,12 @@ if sys.platform != 'win32': # We could override modify, but it is only slightly when the "events" change. def get_key(self, fileobj): - fd = _fileobj_to_fd(fileobj) - return self._fd_to_key[fd] + return self._map[fileobj] def get_map(self): - """Return a mapping of file objects to selector keys.""" - # Horribly inefficient - # It should never be called and exists just to prevent issues if e.g. - # python decides to use it for debug purposes. - return {k.fileobj: k for k in self._fd_to_key.values()} + """Return a mapping of file objects or file descriptors to + selector keys.""" + return self._map else: diff --git a/tests/test_events.py b/tests/test_events.py index 81409abae..c075af095 100644 --- a/tests/test_events.py +++ b/tests/test_events.py @@ -45,6 +45,7 @@ import sys import gi import gi.events import asyncio +import socket import threading from gi.repository import GLib @@ -262,3 +263,63 @@ class GLibEventLoopPolicyTests(unittest.TestCase): GLib.MainLoop().run() loop.close() + + @unittest.skipIf(sys.platform == 'win32', 'add reader/writer not implemented') + def test_source_fileobj_fd(self): + """Regression test for + https://gitlab.gnome.org/GNOME/pygobject/-/issues/689 + """ + class Echo: + def __init__(self, sock, expect_bytes): + self.sock = sock + self.sent_bytes = 0 + self.expect_bytes = expect_bytes + self.done = asyncio.Future() + self.data = bytes() + + def send(self): + if self.done.done(): + return + if self.sent_bytes < len(self.data): + self.sent_bytes += self.sock.send( + self.data[self.sent_bytes:]) + print('sent', self.data) + if self.sent_bytes >= self.expect_bytes: + self.done.set_result(None) + self.sock.shutdown(socket.SHUT_WR) + + def recv(self): + if self.done.done(): + return + self.data += self.sock.recv(self.expect_bytes) + print('received', self.data) + if len(self.data) >= self.expect_bytes: + self.sock.shutdown(socket.SHUT_RD) + + async def run(): + loop = asyncio.get_running_loop() + s1, s2 = socket.socketpair() + sample = b'Hello!' + e = Echo(s1, len(sample)) + # register using file object and file descriptor + loop.add_reader(s1, e.recv) + loop.add_writer(s1.fileno(), e.send) + s2.sendall(sample) + await asyncio.wait_for(e.done, timeout=2.0) + echo = bytes() + for _ in range(len(sample)): + echo += s2.recv(len(sample)) + if len(echo) == len(sample): + break + # remove using file object and file descriptor + loop.remove_reader(s1) + loop.remove_writer(s1.fileno()) + s1.close() + s2.close() + # check if the data was echoed correctly + self.assertEqual(sample, echo) + + policy = self.create_policy() + loop = policy.get_event_loop() + loop.run_until_complete(run()) + loop.close() -- GitLab