diff --git a/houdini/data/__init__.py b/houdini/data/__init__.py index a998b45..f16b78a 100644 --- a/houdini/data/__init__.py +++ b/houdini/data/__init__.py @@ -28,7 +28,7 @@ class BaseCrumbsCollection(dict): try: return self[k] except KeyError as e: - query = self._model.load(parent=self._inventory_model).where( + query = self._inventory_model.load(parent=self._model).where( (self._inventory_key_column == self._inventory_id) & (self._inventory_value_column == k) ) if self._is_inventory else self._model.query.where(self._model_key_column == k) result = await query.gino.first() @@ -38,25 +38,36 @@ class BaseCrumbsCollection(dict): raise e async def set(self, k=None, **kwargs): - if self._is_inventory and k: + if self._is_inventory: kwargs = {self._inventory_key: self._inventory_id, self._inventory_value: k} - self[k] = await self._inventory_model.create(**kwargs) + model_instance = await self._inventory_model.create(**kwargs) + k = getattr(model_instance, self._inventory_value) + self[k] = model_instance else: model_instance = await self._model.create(**kwargs) k = getattr(model_instance, self._key) self[k] = model_instance return self[k] + async def delete(self, k): + query = self._inventory_model.delete.where( + (self._inventory_key_column == self._inventory_id) & (self._inventory_value_column == k) + ) if self._is_inventory else self._model.delete.where(self._model_key_column == k) + await query.gino.status() + if k in self: + del self[k] + async def __collect(self): - query = self._model.load(parent=self._inventory_model).where( + query = self._inventory_model.load(parent=self._model).where( self._inventory_key_column == self._inventory_id ) if self._is_inventory else self._model.query async with db.transaction(): collected = query.gino.iterate() - self.update( - {getattr(model_instance, self._key): model_instance async for model_instance in collected} - ) + self.update({ + getattr(model_instance, self._inventory_value if self._is_inventory else self._key): model_instance + async for model_instance in collected + }) @classmethod async def get_collection(cls, *args, **kwargs):