在獨立子線程中執行非同步操作
第一次寫 Python 非同步語法的紀錄,網路上很多非同步文章,但是沒看過有把非同步操作放在獨立線程中執行的程式碼實例,於是自己寫了一個。
由於用了 viztracer 分析好像有點厲害所以把他搬到文檔庫,不然這篇原本放在備忘錄。
程式碼說明
設計思路是把函式和參數打包後丟給事件迴圈運行,用一個單獨的線程用於執行事件迴圈,再把任務註冊到這個事件迴圈。
首先先建立一個 dataclass 用於把要執行函式的以及函式輸入打包
@dataclass
class Task:
task_id: str
func: Callable[..., Any]
args: Tuple[Any, ...] = ()
kwargs: Optional[Dict[str, Any]] = None
def __post_init__(self) -> None:
self.kwargs = self.kwargs or {}
接下來就是建立負責管理事件迴圈和線程的類別,初始化如下,項目有點多。
- max_workers 用於限制最高並發數量
- is_running 是程式旗標,標記線程是否還在運行
- sem 是限制最高並發的 semaphore 鎖
- task_queue 用於緩衝任務,存放還未執行的任務列表,因為有 max_workers 限制最高並發數量
- current_tasks 是馬上要執行的任務列表,內容是從 task_queue 取出的
- results 用於儲存運行結果,使用字典以便根據 task_id 尋找結果,queue 無法完成這項要求
class AsyncService:
def __init__(self, logger: Logger, max_workers: int = 5) -> None:
self.max_workers = max_workers
self.logger = logger
self.is_running = False
self.loop: Optional[asyncio.AbstractEventLoop] = None
self.sem = asyncio.Semaphore(self.max_workers)
self.thread: Optional[threading.Thread] = None
self._lock = threading.Lock()
self.task_queue: queue.Queue[Task] = queue.Queue()
self.current_tasks: list[asyncio.Task[Any]] = []
self.results: Dict[str, Any] = {}
接下來我們介紹程式架構:
add_task
和add_tasks
作為外部接口把任務放進 task_queue 中- 每次
add_task
呼叫_ensure_thread_active
確認子線程是否存活 - 子線程執行
_start_event_loop
,這個函式會呼叫_schedule_tasks
,並且使用 try-finally 語法管理事件迴圈的關閉 _schedule_tasks
是一個無限迴圈,用於從 task_queue 中取出任務,使用 asyncio.create_task 註冊到事件迴圈_run_task
把輸入函式解包並且 await 執行,再把輸出結果放進 results 字典
有點繞,先看最簡單的 _run_task
,把 Task
dataclass 內容取出後執行再用線程鎖控制輸出寫入,雖然只有單線程應該不需要線程鎖,但是考量三個原因還是把他加上去:
- 防範於未然,字典不是線程安全的,要是哪天忘了他只能用單線程就造成競爭危害了
- 雖然 99.9999% 確定不會造成競爭危害(因為事件迴圈本質是順序執行,除非在寫入時也 await 才有可能造成非原子操作導致競爭危害),但是不想賭那個 0.00001% 的問題
- 相較 io 任務而言這個鎖的開銷簡直微乎其微
async def _run_task(self, task: Task) -> Any:
async with self.sem:
print(
f"Task {task.func.__name__} with args {task.args} and kwargs {task.kwargs} start running!"
)
try:
result = await task.func(*task.args, **task.kwargs) # type: ignore
with self._lock:
self.results[task.task_id] = result
return result
except Exception as e:
self.logger.error(f"Error processing task {task.task_id}: {e}")
with self._lock:
self.results[task.task_id] = None
剛才介紹完 _run_task
,接下來介紹呼叫 _run_task
的 _schedule_tasks
,前者真正執行任務,後者管理任務,是一個中間人的角色,負責從 task_queue 中取出任務註冊到事件迴圈中,並且放到 current_tasks 這個列表中準備執行。
async def _schedule_tasks(self) -> None:
while True:
self.current_tasks = [task for task in self.current_tasks if not task.done()]
if self.task_queue.empty() and not self.current_tasks:
break
while not self.task_queue.empty() and len(self.current_tasks) < self.max_workers:
try:
task = self.task_queue.get_nowait()
task_obj = asyncio.create_task(self._run_task(task))
self.current_tasks.append(task_obj)
except queue.Empty:
break
if self.current_tasks:
await asyncio.wait(self.current_tasks, return_when=asyncio.FIRST_COMPLETED)
現在完成了事件註冊和運行事件,為了把非同步操作放在獨立線程中執行,還缺少運行事件迴圈以及把事件迴圈放到線程中執行這兩件事情,使用 _start_event_loop
還有 _ensure_thread_active
完成,完整程式碼如下,也可以在我的 Github 中找到:
import asyncio
import queue
import threading
import time
from dataclasses import dataclass
from logging import Logger, getLogger
from typing import Any, Dict, Tuple, Callable, Optional
from help import BLOCK_MSG, NOT_BLOCK_MSG, io_task, print_thread_id, timer
@dataclass
class Task:
task_id: str
func: Callable[..., Any]
args: Tuple[Any, ...] = ()
kwargs: Optional[Dict[str, Any]] = None
def __post_init__(self) -> None:
self.kwargs = self.kwargs or {}
class AsyncService:
def __init__(self, logger: Logger, max_workers: int = 5) -> None:
# 載入變數
self.max_workers = max_workers
self.logger = logger
# 任務運行相關設定
self.is_running = False
self.loop: Optional[asyncio.AbstractEventLoop] = None
self.sem = asyncio.Semaphore(self.max_workers)
self.thread: Optional[threading.Thread] = None
self._lock = threading.Lock()
# 儲存任務和結果的資料結構
self.task_queue: queue.Queue[Task] = queue.Queue()
self.current_tasks: list[asyncio.Task[Any]] = []
self.results: Dict[str, Any] = {}
def add_task(self, task: Task) -> None:
self.task_queue.put(task)
self._ensure_thread_active()
def add_tasks(self, tasks: list[Task]) -> None:
for task in tasks:
self.task_queue.put(task)
self._ensure_thread_active()
def fetch_result(self, task_id: str) -> Optional[Any]:
with self._lock:
return self.results.pop(task_id, None)
def fetch_results(self, max_results: int = 0) -> Dict[str, Any]:
with self._lock:
if max_results <= 0:
results_to_return = self.results.copy()
self.results.clear()
return results_to_return
keys = list(self.results.keys())[:max_results]
return {key: self.results.pop(key) for key in keys}
def shutdown(self, timeout: Optional[float] = None) -> None:
if self.thread is not None:
self.thread.join(timeout=timeout)
print(f"\n===no job! clearing thread {self.thread.native_id}===")
self.thread = None
self.is_running = False
print(f"===thread cleared! result: {self.thread}===\n")
def _ensure_thread_active(self) -> None:
with self._lock:
if not self.is_running or self.thread is None or not self.thread.is_alive():
self.is_running = True
self.thread = threading.Thread(target=self._start_event_loop)
self.thread.start()
def _start_event_loop(self) -> None:
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
try:
self.loop.run_until_complete(self._schedule_tasks())
finally:
self.loop.close()
self.loop = None
self.is_running = False
self.current_tasks.clear()
async def _schedule_tasks(self) -> None:
while True:
self.current_tasks = [task for task in self.current_tasks if not task.done()]
if self.task_queue.empty() and not self.current_tasks:
break
while not self.task_queue.empty() and len(self.current_tasks) < self.max_workers:
try:
task = self.task_queue.get_nowait()
task_obj = asyncio.create_task(self._run_task(task))
self.current_tasks.append(task_obj)
except queue.Empty:
break
if self.current_tasks:
await asyncio.wait(self.current_tasks, return_when=asyncio.FIRST_COMPLETED)
async def _run_task(self, task: Task) -> Any:
async with self.sem:
print(
f"Task {task.func.__name__} with args {task.args} and kwargs {task.kwargs} start running!"
)
try:
result = await task.func(*task.args, **task.kwargs) # type: ignore
with self._lock:
self.results[task.task_id] = result
return result
except Exception as e:
self.logger.error(f"Error processing task {task.task_id}: {e}")
with self._lock:
self.results[task.task_id] = None
@timer
def test() -> None:
print_thread_id()
logger = getLogger()
task_groups = [
[(1, "A1"), (2, "A2"), (3, "A3")],
[(3, "B1"), (4, "B2"), (5, "B3")],
[(3, "C1"), (4, "C2"), (5, "C3")],
[(1, "D1"), (2, "D2"), (3, "D3")],
]
manager = AsyncService(logger, max_workers=5)
# 新增第一批任務
for group in task_groups[:-1]:
tasks = [Task(task[1], io_task, task) for task in group]
manager.add_tasks(tasks)
print(NOT_BLOCK_MSG)
# 模擬主執行緒工作需要 2.5 秒,在程式中間取得結果
# 會顯示 A1/A2 的結果,因為它們在 2.5 秒後完成
time.sleep(2.5)
results = manager.fetch_results()
print(NOT_BLOCK_MSG, "(2s waiting for main thread itself)") # not blocked
for result in results:
print(result)
# 等待子執行緒結束,造成阻塞
manager.shutdown()
print(BLOCK_MSG)
for _ in range(3):
results = manager.fetch_results()
for result in results:
print(result)
# 在thread關閉後新增第二批任務
tasks = [Task(task[1], io_task, task) for task in task_groups[-1]]
manager.add_tasks(tasks)
manager.shutdown()
results = manager.fetch_results()
for result in results:
print(result)
if __name__ == "__main__":
print_thread_id()
test()
test
函式是一個簡單的使用範例,實際運行是 11 秒,讀者可以自行計算秒數驗證是否和理論相符。