在獨立子線程中執行非同步操作
第一次寫 Python 非同步語法的紀錄,網路上很多非同步文章,但是沒看過有把非同步操作放在獨立線程中執行的程式碼實例,於是自己寫了一個。
程式說明
設計思路是把函式和參數打包後丟給事件迴圈運行,用一個單獨的線程用於執行事件迴圈,再把任務註冊到這個事件迴圈。
首先先建立一個 dataclass 用於把要執行函式的以及函式輸入打包
@dataclass
class Task:
task_id: str
func: Callable[..., Any]
args: Tuple[Any, ...] = ()
kwargs: Optional[Dict[str, Any]] = None
def __post_init__(self) -> None:
self.kwargs = self.kwargs or {}
接下來就是建立負責管理事件迴圈和線程的類別,初始化如下,項目有點多。
- max_workers 用於限制最高並發數量
- is_running 是程式旗標,標記線程是否還在運行
- sem 是限制最高並發的 semaphore 鎖
- task_queue 用於緩衝任務,存放還未執行的任務列表,因為有 max_workers 限制最高並發數量
- current_tasks 是馬上要執行的任務列表,內容是從 task_queue 取出的
- results 用於儲存運行結果,使用字典以便根據 task_id 尋找結果,queue 無法完成這項要求
class AsyncService:
def __init__(self, logger: Logger, max_workers: int = 5) -> None:
self.max_workers = max_workers
self.logger = logger
self.is_running = False
self.loop: Optional[asyncio.AbstractEventLoop] = None
self.sem = asyncio.Semaphore(self.max_workers)
self.thread: Optional[threading.Thread] = None
self._lock = threading.Lock()
self.task_queue: queue.Queue[Task] = queue.Queue()
self.current_tasks: list[asyncio.Task[Any]] = []
self.results: Dict[str, Any] = {}
接下來我們介紹程式架構:
add_task
和add_tasks
作為外部接口把任務放進 task_queue 中- 每次
add_task
呼叫_ensure_thread_active
確認子線程是否存活 - 子線程執行
_start_event_loop
,這個函式會呼叫_schedule_tasks
,並且使用 try-finally 語法管理事件迴圈的關閉 _schedule_tasks
是一個無限迴圈,用於從 task_queue 中取出任務,使用 asyncio.create_task 註冊到事件迴圈_run_task
把輸入函式解包並且 await 執行,再把輸出結果放進 results 字典
有點繞,先看最簡單的 _run_task
,把 Task
dataclass 內容取出後執行再用線程鎖控制輸出寫入,雖然只有單線程應該不 需要線程鎖,但是考量三個原因還是把他加上去:
- 防範於未然,字典不是線程安全的,要是哪天忘了他只能用單線程就造成競爭危害了
- 雖然 99.9999% 確定不會造成競爭危害(因為事件迴圈本質是順序執行,除非在寫入時也 await 才有可能造成非原子操作導致競爭危害),但是不想賭那個 0.00001% 的問題
- 相較 io 任務而言這個鎖的開銷簡直微乎其微
async def _run_task(self, task: Task) -> Any:
async with self.sem:
print(
f"Task {task.func.__name__} with args {task.args} and kwargs {task.kwargs} start running!"
)
try:
result = await task.func(*task.args, **task.kwargs) # type: ignore
with self._lock:
self.results[task.task_id] = result
return result
except Exception as e:
self.logger.error(f"Error processing task {task.task_id}: {e}")
with self._lock:
self.results[task.task_id] = None
剛才介紹完 _run_task
,接下來介紹呼叫 _run_task
的 _schedule_tasks
,前者真正執行任務,後者管理任務,是一個中間人的角色,負責從 task_queue 中取出任務註冊到事件迴圈中,並且放到 current_tasks 這個列表中準備執行。
async def _schedule_tasks(self) -> None:
while True:
self.current_tasks = [task for task in self.current_tasks if not task.done()]
if self.task_queue.empty() and not self.current_tasks:
break
while not self.task_queue.empty() and len(self.current_tasks) < self.max_workers:
try:
task = self.task_queue.get_nowait()
task_obj = asyncio.create_task(self._run_task(task))
self.current_tasks.append(task_obj)
except queue.Empty:
break
if self.current_tasks:
await asyncio.wait(self.current_tasks, return_when=asyncio.FIRST_COMPLETED)
現在完成了事件註冊和運行事件,為了把非同步操作放在獨立線程中執行,還缺少運行事件迴圈以及把事件迴圈放到線程中執行這兩件事情,使用 _start_event_loop
還有 _ensure_thread_active
完成,完整程式碼如下,也可以在我的 Github 中找到:
import asyncio
import queue
import threading
import time
from dataclasses import dataclass
from logging import Logger, getLogger
from typing import Any, Dict, Tuple, Callable, Optional
from help import BLOCK_MSG, NOT_BLOCK_MSG, io_task, print_thread_id, timer
@dataclass
class Task:
task_id: str
func: Callable[..., Any]
args: Tuple[Any, ...] = ()
kwargs: Optional[Dict[str, Any]] = None
def __post_init__(self) -> None:
self.kwargs = self.kwargs or {}
class AsyncService:
def __init__(self, logger: Logger, max_workers: int = 5) -> None:
# 載入變數
self.max_workers = max_workers
self.logger = logger
# 任務運行相關設定
self.is_running = False
self.loop: Optional[asyncio.AbstractEventLoop] = None
self.sem = asyncio.Semaphore(self.max_workers)
self.thread: Optional[threading.Thread] = None
self._lock = threading.Lock()
# 儲存任務和結果的資料結構
self.task_queue: queue.Queue[Task] = queue.Queue()
self.current_tasks: list[asyncio.Task[Any]] = []
self.results: Dict[str, Any] = {}
def add_task(self, task: Task) -> None:
self.task_queue.put(task)
self._ensure_thread_active()
def add_tasks(self, tasks: list[Task]) -> None:
for task in tasks:
self.task_queue.put(task)
self._ensure_thread_active()
def fetch_result(self, task_id: str) -> Optional[Any]:
with self._lock:
return self.results.pop(task_id, None)
def fetch_results(self, max_results: int = 0) -> Dict[str, Any]:
with self._lock:
if max_results <= 0:
results_to_return = self.results.copy()
self.results.clear()
return results_to_return
keys = list(self.results.keys())[:max_results]
return {key: self.results.pop(key) for key in keys}
def shutdown(self, timeout: Optional[float] = None) -> None:
if self.thread is not None:
self.thread.join(timeout=timeout)
print(f"\n===no job! clearing thread {self.thread.native_id}===")
self.thread = None
self.is_running = False
print(f"===thread cleared! result: {self.thread}===\n")
def _ensure_thread_active(self) -> None:
with self._lock:
if not self.is_running or self.thread is None or not self.thread.is_alive():
self.is_running = True
self.thread = threading.Thread(target=self._start_event_loop)
self.thread.start()
def _start_event_loop(self) -> None:
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
try:
self.loop.run_until_complete(self._schedule_tasks())
finally:
self.loop.close()
self.loop = None
self.is_running = False
self.current_tasks.clear()
async def _schedule_tasks(self) -> None:
while True:
self.current_tasks = [task for task in self.current_tasks if not task.done()]
if self.task_queue.empty() and not self.current_tasks:
break
while not self.task_queue.empty() and len(self.current_tasks) < self.max_workers:
try:
task = self.task_queue.get_nowait()
task_obj = asyncio.create_task(self._run_task(task))
self.current_tasks.append(task_obj)
except queue.Empty:
break
if self.current_tasks:
await asyncio.wait(self.current_tasks, return_when=asyncio.FIRST_COMPLETED)
async def _run_task(self, task: Task) -> Any:
async with self.sem:
print(
f"Task {task.func.__name__} with args {task.args} and kwargs {task.kwargs} start running!"
)
try:
result = await task.func(*task.args, **task.kwargs) # type: ignore
with self._lock:
self.results[task.task_id] = result
return result
except Exception as e:
self.logger.error(f"Error processing task {task.task_id}: {e}")
with self._lock:
self.results[task.task_id] = None
@timer
def test() -> None:
print_thread_id()
logger = getLogger()
task_groups = [
[(1, "A1"), (2, "A2"), (3, "A3")],
[(3, "B1"), (4, "B2"), (5, "B3")],
[(3, "C1"), (4, "C2"), (5, "C3")],
[(1, "D1"), (2, "D2"), (3, "D3")],
]
manager = AsyncService(logger, max_workers=5)
# 新增第一批任務
for group in task_groups[:-1]:
tasks = [Task(task[1], io_task, task) for task in group]
manager.add_tasks(tasks)
print(NOT_BLOCK_MSG)
# 模擬主執行緒工作需要 2.5 秒,在程式中間取得結果
# 會顯示 A1/A2 的結果,因為它們在 2.5 秒後完成
time.sleep(2.5)
results = manager.fetch_results()
print(NOT_BLOCK_MSG, "(2s waiting for main thread itself)") # not blocked
for result in results:
print(result)
# 等待子執行緒結束,造成阻塞
manager.shutdown()
print(BLOCK_MSG)
for _ in range(3):
results = manager.fetch_results()
for result in results:
print(result)
# 在thread關閉後新增第二批任務
tasks = [Task(task[1], io_task, task) for task in task_groups[-1]]
manager.add_tasks(tasks)
manager.shutdown()
results = manager.fetch_results()
for result in results:
print(result)
if __name__ == "__main__":
print_thread_id()
test()
test
函式是一個簡單的使用範例,實際運行是 11 秒,讀者可以自行計算秒數驗證是否和理論相符。
自我檢討和心得
搜尋資料時看到有人建議撰寫主程式是非同步,然後把同步語法放到子線程中執行會比較好,寫的時候不太認同,真的用函式的時候就認同了,因為即使已經包裝成只要呼叫 add_task
和輸入 Task
,實際使用時還是不太方便。
第二個是層層包裹的語句造成理解不易,使用 add_task
加入任務後會經過 _ensure_thread_active
確認線程是否存活並且建立線程,線程裡面要使用 _start_event_loop
建立事件迴圈,再使用 _schedule_tasks
把事件註冊到迴圈中,最後用 _run_task
把 Task
dataclass 解包並且執行。這呼應第一個問題:如果去掉在子線程執行事件迴圈這件事,就可以刪掉前兩個方法,簡化為只需要註冊和運行而已。會這樣寫的原因除了自己想練習以外,也是因為前陣子寫了一個「把任務丟到子線程中執行」,所以用同樣想法寫了事件迴圈版本,結果比想像中的麻煩多了。不過都是試了才知道,畢竟網路上又沒這種文章。
第三是子線程中包含多個事件迴圈的運行管理,但筆者還沒到那個程度,以這個架構繼續延伸的話應該是輸入時加上 loop id 選擇要使用哪個迴圈。
第四有關記憶體效率,task_queue
和 current_tasks
疊床架屋,task_queue
用於暫存還沒執行的任務,current_tasks
存放已經從 task_queue
取出準備要執行的任務,要兩個物件管理有點浪費資源,有 semaphore 應該就不需要這兩個東西。附帶一提 results
不使用 queue 的原因是用戶可能會想根據 task_id 取得結果,但是 queue 只能從頭尾取值達不到這項要求。
最後補充,這個腳本跑 mypy --strict 可以過的唷。
更新:移除佇列版本
關於檢討中說到的疊床架屋問題,寫了一個不需要 task_queue
和 current_tasks
的版本,並且根據這篇文章使用 run_coroutine_threadsafe
和 call_soon_threadsafe
在主線程要求子線程運行任務,並且加上 threading.Event 確保事件迴圈正常啟動避免死鎖。
class AsyncService:
def __init__(self, logger: Logger, max_workers: int = 5) -> None:
self.logger = logger
# 任務運行相關設定
self._running_tasks = 0
self.loop: Optional[asyncio.AbstractEventLoop] = None
self.thread: Optional[threading.Thread] = None
self._lock = threading.Lock()
self._loop_ready = threading.Event()
self.sem = asyncio.Semaphore(max_workers)
# 儲存結果
self.results: Dict[str, Any] = {}
def add_task(self, task: Task) -> None:
self._ensure_thread_active()
with self._lock:
self._running_tasks += 1
assert self.loop is not None
asyncio.run_coroutine_threadsafe(self._schedule_tasks(task), self.loop)
def add_tasks(self, tasks: list[Task]) -> None:
for task in tasks:
self.add_task(task)
def fetch_result(self, task_id: str) -> Optional[Any]:
with self._lock:
return self.results.pop(task_id, None)
def fetch_results(self, max_results: int = 0) -> Dict[str, Any]:
with self._lock:
if max_results <= 0:
results_to_return = self.results.copy()
self.results.clear()
return results_to_return
keys = list(self.results.keys())[:max_results]
return {key: self.results.pop(key) for key in keys}
def shutdown(self, timeout: Optional[float] = None) -> None:
if self.thread is None or self.loop is None:
return
while True:
with self._lock:
if self._running_tasks == 0:
break
self.loop.call_soon_threadsafe(self.loop.stop) # 停止事件迴圈
self.thread.join(timeout=timeout)
print(f"\n===no job! clearing thread {self.thread.native_id}===")
self.thread = None
print(f"===thread cleared! result: {self.thread}===\n")
def _ensure_thread_active(self) -> None:
with self._lock:
if self.thread is None or not self.thread.is_alive():
self._loop_ready.clear()
self.thread = threading.Thread(target=self._start_event_loop)
self.thread.start()
self._loop_ready.wait() # 等待事件迴圈啟動
def _start_event_loop(self) -> None:
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self._loop_ready.set()
self.loop.run_forever()
try:
self.loop.close()
finally:
self.loop = None
self._loop_ready.clear()
async def _schedule_tasks(self, task: Task) -> None:
async with self.sem:
print(
f"Task {task.func.__name__} with args {task.args} and kwargs {task.kwargs} start running!"
)
try:
result = await task.func(*task.args, **task.kwargs)
with self._lock:
self.results[task.task_id] = result
except Exception as e:
self.logger.error(f"Error processing task {task.task_id}: {e}")
with self._lock:
self.results[task.task_id] = None
finally:
with self._lock:
self._running_tasks -= 1
討論:死鎖問題
藉此機會練習多線程的問題,剛改成這個版本時沒有使用 threading.Event
會產生死鎖,以下分析死鎖產生原因:
- 在
add_task
時馬上呼叫_ensure_thread_active
,得到鎖,建立並且啟動線程,釋放鎖,return,同一時間子線程正在執行_start_event_loop
- 回到
add_task
,取得鎖,執行run_coroutine_threadsafe
,但是_start_event_loop
中建立 self.loop 的工作尚未完成,引發 NoneTypeError - 鎖釋放失敗,下一次
add_task
又要求鎖,造成死鎖(這也會造成run_coroutine_threadsafe
本身的鎖產生死鎖)
加上 Event 則確保 self.loop 成功建立,在 asyncio.set_event_loop
完成之後才發送 set
訊號告訴主線程可以繼續工作,避免死鎖問題。這裡和網路教學相反的是是由子線程告訴主線程可以繼續了,網路教學通常是由主線程輸入 Event 告訴子線程開始工作。
# 一開始沒有鎖的版本,方便讀者比較
def _start_event_loop(self) -> None:
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
# self._loop_ready.set() # 不使用Event
self.loop.run_forever() # 阻塞線程
# ...
def _ensure_thread_active(self) -> None:
with self._lock:
if self.thread is None or not self.thread.is_alive():
# self._loop_ready.clear() # 不使用Event
self.thread = threading.Thread(target=self._start_event_loop)
self.thread.start()
# self._loop_ready.wait() # 不使用Event