代码之家  ›  专栏  ›  技术社区  ›  darkie7

我的一些郊游并不(总是)开始

  •  0
  • darkie7  · 技术社区  · 1 年前

    我正在与Langchain和FastApi合作。基本上,我正在创建一个StreamingResponse,它流式返回JSON差异。diff是通过将JSON文档与其以前的版本进行比较来创建的。JSON的各种属性是由它通过langchain链从OpenAI接收的信息生成的(使用LCEL( https://python.langchain.com/docs/expression_language/ )(也流式传输)。

    为了获得所有属性,我运行了一个名为“run”的函数,该函数会触发其他(异步)函数。这些async函数向async.queue添加了一个jsondiff anext 方法卸载队列。

    代码相当复杂,我试图将相关部分浓缩在这里:

    这是创建异步生成器的代码。RunAgent的run方法实际上构建文档并填充队列。

    from jsonpatch import JsonPatch
    import asyncio
    from pydantic import BaseModel
    from src.llm.chain import meta_chain, sub_object_meta_chain, sub_object_details_chain
    from src.schema import MyObject
    
    
    class IterableAgent(BaseModel):
        queue: asyncio.Queue | None = None
        done: bool = False
        timeout: int = 10
    
        async def __anext__(self):
            if self.done and self.queue.empty():
                raise StopAsyncIteration
            try:
                return await asyncio.wait_for(self.queue.get(), timeout=self.timeout)
            except asyncio.TimeoutError:
                raise StopAsyncIteration
    
        def __aiter__(self):
            self.done = False
            self.queue = asyncio.Queue()
            asyncio.create_task(self.run())
            return self
    
        @abstractmethod
        def run():
            raise NotImplementedError
    
    
    class RunAgent(IterableAgent):
        _last_object: str = "{}"
    
        def _get_patch(self) -> str:
            new_object = self.object.model_dump_json()
            patch = JsonPatch.from_diff(self._last_object, new_object).patch
            self._last_object = new_object
            return patch
    
        def _continue_response(self) -> None:
            patch = self._get_patch()
            if patch:
                self.queue.put_nowait(patch)
    
        async def _add_meta(self):
            inputs = dict()  # Any inputs I need here for the chain...
            async for result in meta_chain.astream(inputs):
                self.object.meta = result
                self._continue_response()
    
        async def _add_sub_objects_meta(self):
            inputs = dict()  # Any inputs I need here for the chain...
            async for result in sub_object_meta_chain.astream(inputs):
                self.object.sub_objects = result
                self._continue_response()
    
        async def _add_sub_objects_details(self, so):
            inputs = dict(so=so)  # Any inputs I need here for the chain...
            async for result in sub_object_details_chain.astream(inputs):
                so.details = result
                self._continue_response()
    
        async def run(self):
            self.object = MyObject()
            await self._add_meta()
            await self._add_sub_objects_meta()
            cors = [self._add_sub_objects_details(so) for so in self.object.sub_objects]
            # I tried to create tasks instead of just having coroutines, no luck there...
            tasks = [asyncio.create_task(x) for x in cors]
            await asyncio.gather(*tasks)
            self.done = True
    

    然后,在我的FastApi路由器中还有一小段代码,它使用生成器向StreamingResponse提供信息。

    
    @router.post("/new")
    async def new_document(req: MyRequest):
        \# The agent is implemented to be an AsyncIterator
        agent = RunAgent() # I would typically unpack MyRequest and feed the relevant data into the  agent initialization
        return StreamingResponse(agent, media_type="application/json")
    

    我的问题

    现在,这很有效,但有时,它只是中途停止。排队的时间刚好超时。任务似乎还没有开始。似乎已经开始的任何任务/协同程序都完成了。例如,我看到的是生成了基本对象的元数据,但没有生成子对象的元数据。或者其中一个子对象的详细信息丢失。我认为有一些比赛条件,我只知道在哪里。。。。

    0 回复  |  直到 1 年前
        1
  •  2
  •   jsbueno    1 年前

    Python asyncio有一个特殊的设计选择,这会导致一个问题,很可能就是您遇到的问题:

    使用创建的任务 .create_task 不会被事件循环硬引用。(它只保留了对它们的微弱引用)。虽然这有时不会出现在简单的运行中,或者当任务很少时——因为弱引用的任务实际上是在异步循环核心的第一次执行中开始的,当有很多任务(在我进行的一些测试中大约有2000个)时,或者可能更少,当任务创建在多个地方时,比如在这段代码中,任务可能会消失得无影无踪。

    这很可能是您的问题,解决方案只是保留对由创建的任务的引用 .create_task 调用-一个简单的类级别 set 可以保留这些(或实例级别,如果您有 __init__ 方法):

    ...
    
    class IterableAgent(BaseModel):
        queue: asyncio.Queue | None = None
        done: bool = False
        timeout: int = 10
        running_tasks: set[asyncio.Task] = set()
    
        ...
    
        def __aiter__(self):
            self.done = False
            self.queue = asyncio.Queue()
            task = asyncio.create_task(self.run())
            self.running_tasks.add(task)
            task.add_done_callback(self.running_tasks.discard)
            return self
        
       ...
    
    

    (当然,如果有其他调用,请使用相同的模式 .create_task )。

    至于是否记录了这种行为,请注意“重要”注释: https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task