Skip to content

Commit

Permalink
- correctly split Detections and Image_Metadata
Browse files Browse the repository at this point in the history
- minor typing fixes

-improve test speed
  • Loading branch information
denniswittich committed Nov 8, 2024
1 parent 76fd8ce commit e823122
Show file tree
Hide file tree
Showing 16 changed files with 108 additions and 105 deletions.
2 changes: 1 addition & 1 deletion learning_loop_node/data_classes/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from enum import Enum
from typing import Optional, Union

from .detections import Point, Shape
from .general import Category, Context
from .image_metadata import Point, Shape

KWONLY_SLOTS = {'kw_only': True, 'slots': True} if sys.version_info >= (3, 10) else {}

Expand Down
20 changes: 10 additions & 10 deletions learning_loop_node/detector/detector_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,13 @@ async def upload(sid, data: Dict) -> Optional[Dict]:
detection_data = data.get('detections', {})
if detection_data and self.detector_logic.is_initialized:
try:
detections = from_dict(data_class=ImageMetadata, data=detection_data)
image_metadata = from_dict(data_class=ImageMetadata, data=detection_data)
except Exception as e:
self.log.exception('could not parse detections')
return {'error': str(e)}
detections = self.add_category_id_to_detections(self.detector_logic.model_info, detections)
image_metadata = self.add_category_id_to_detections(self.detector_logic.model_info, image_metadata)
else:
detections = ImageMetadata()
image_metadata = ImageMetadata()

tags = data.get('tags', [])
tags.append('picked_by_system')
Expand All @@ -190,7 +190,7 @@ async def upload(sid, data: Dict) -> Optional[Dict]:

loop = asyncio.get_event_loop()
try:
await loop.run_in_executor(None, self.outbox.save, data['image'], detections, tags, source, creation_date)
await loop.run_in_executor(None, self.outbox.save, data['image'], image_metadata, tags, source, creation_date)
except Exception as e:
self.log.exception('could not upload via socketio')
return {'error': str(e)}
Expand Down Expand Up @@ -377,28 +377,28 @@ async def upload_images(self, images: List[bytes], source: Optional[str], creati
for image in images:
await loop.run_in_executor(None, self.outbox.save, image, ImageMetadata(), ['picked_by_system'], source, creation_date)

def add_category_id_to_detections(self, model_info: ModelInformation, detections: ImageMetadata):
def add_category_id_to_detections(self, model_info: ModelInformation, image_metadata: ImageMetadata):
def find_category_id_by_name(categories: List[Category], category_name: str):
category_id = [category.id for category in categories if category.name == category_name]
return category_id[0] if category_id else ''

for box_detection in detections.box_detections:
for box_detection in image_metadata.box_detections:
category_name = box_detection.category_name
category_id = find_category_id_by_name(model_info.categories, category_name)
box_detection.category_id = category_id
for point_detection in detections.point_detections:
for point_detection in image_metadata.point_detections:
category_name = point_detection.category_name
category_id = find_category_id_by_name(model_info.categories, category_name)
point_detection.category_id = category_id
for segmentation_detection in detections.segmentation_detections:
for segmentation_detection in image_metadata.segmentation_detections:
category_name = segmentation_detection.category_name
category_id = find_category_id_by_name(model_info.categories, category_name)
segmentation_detection.category_id = category_id
for classification_detection in detections.classification_detections:
for classification_detection in image_metadata.classification_detections:
category_name = classification_detection.category_name
category_id = find_category_id_by_name(model_info.categories, category_name)
classification_detection.category_id = category_id
return detections
return image_metadata

def register_sio_events(self, sio_client: AsyncClient):
pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ def forget_old_detections(self) -> None:
for detection in self.recent_observations
if not detection.is_older_than(self.reset_time)]

def get_causes_to_upload(self, detections: ImageMetadata) -> List[str]:
def get_causes_to_upload(self, image_metadata: ImageMetadata) -> List[str]:
causes = set()
for detection in detections.box_detections + detections.point_detections + detections.segmentation_detections + detections.classification_detections:
for detection in image_metadata.box_detections + image_metadata.point_detections + image_metadata.segmentation_detections + image_metadata.classification_detections:
if isinstance(detection, SegmentationDetection):
# self.recent_observations.append(Observation(detection))
causes.add('segmentation_detection')
Expand Down
8 changes: 4 additions & 4 deletions learning_loop_node/detector/inbox_filter/relevance_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def __init__(self, outbox: Outbox) -> None:
self.outbox: Outbox = outbox

def may_upload_detections(self,
dets: ImageMetadata,
image_metadata: ImageMetadata,
cam_id: str,
raw_image: bytes,
tags: List[str],
Expand All @@ -24,11 +24,11 @@ def may_upload_detections(self,

if cam_id not in self.cam_histories:
self.cam_histories[cam_id] = CamObservationHistory()
causes = self.cam_histories[cam_id].get_causes_to_upload(dets)
if len(dets) >= 80:
causes = self.cam_histories[cam_id].get_causes_to_upload(image_metadata)
if len(image_metadata) >= 80:
causes.append('unexpected_observations_count')
if len(causes) > 0:
tags = tags if tags is not None else []
tags.extend(causes)
self.outbox.save(raw_image, dets, tags, source, creation_date)
self.outbox.save(raw_image, image_metadata, tags, source, creation_date)
return causes
5 changes: 3 additions & 2 deletions learning_loop_node/detector/rest/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import numpy as np
from fastapi import APIRouter, File, Header, Request, UploadFile
from fastapi.responses import JSONResponse

from ...data_classes.image_metadata import ImageMetadata

Expand All @@ -23,6 +22,7 @@ async def http_detect(
source: Optional[str] = Header(None, description='The source of the image (used by learning loop)'),
autoupload: Optional[str] = Header(None, description='Mode to decide whether to upload the image to the learning loop',
examples=['filtered', 'all', 'disabled']),
creation_date: Optional[str] = Header(None, description='The creation date of the image (used by learning loop)')
):
"""
Single image example:
Expand All @@ -46,7 +46,8 @@ async def http_detect(
camera_id=camera_id or mac or None,
tags=tags.split(',') if tags else [],
source=source,
autoupload=autoupload)
autoupload=autoupload,
creation_date=creation_date)
except Exception as exc:
logging.exception('Error during detection of image %s.', file.filename)
raise Exception(f'Error during detection of image {file.filename}.') from exc
Expand Down
11 changes: 5 additions & 6 deletions learning_loop_node/tests/detector/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ async def sio_client() -> AsyncGenerator[socketio.AsyncClient, None]:
try:
await sio.connect(f"ws://localhost:{detector_port}", socketio_path="/ws/socket.io")
try_connect = False
except Exception as e:
logging.warning(f"Connection failed with error: {str(e)}")
except Exception:
logging.exception("Connection failed with error:")
logging.warning('trying again')
await asyncio.sleep(5)
retry_count += 1
Expand All @@ -123,21 +123,20 @@ def mock_detector_logic():
class MockDetectorLogic(DetectorLogic): # pylint: disable=abstract-method
def __init__(self):
super().__init__('mock')
self.detections = ImageMetadata(
self.image_metadata = ImageMetadata(
box_detections=[BoxDetection(category_name="test",
category_id="1",
confidence=0.9,
x=0, y=0, width=10, height=10,
model_name="mock",
)]
)
)])

@property
def is_initialized(self):
return True

def evaluate_with_all_info(self, image: np.ndarray, tags: List[str], source: Optional[str] = None, creation_date: Optional[str] = None):
return self.detections
return self.image_metadata

return MockDetectorLogic()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@

from dacite import from_dict

from ....data_classes.image_metadata import (BoxDetection, ImageMetadata, Point, PointDetection, SegmentationDetection,
Shape)
from ....data_classes import BoxDetection, ImageMetadata, Point, PointDetection, SegmentationDetection, Shape
from ....detector.inbox_filter.cam_observation_history import CamObservationHistory

dirt_detection = BoxDetection(category_name='dirt', x=0, y=0, width=100, height=100,
Expand Down
88 changes: 44 additions & 44 deletions learning_loop_node/tests/detector/test_client_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,60 +108,60 @@ async def test_about_endpoint(test_detector_node: DetectorNode):


async def test_model_version_api(test_detector_node: DetectorNode):
await asyncio.sleep(16)

response = requests.get(f'http://localhost:{GLOBALS.detector_port}/model_version', timeout=30)
assert response.status_code == 200, response.content
response_dict = json.loads(response.content)
assert response_dict['version_control'] == 'follow_loop'
assert response_dict['current_version'] == '1.1'
assert response_dict['target_version'] == '1.1'
assert response_dict['loop_version'] == '1.1'
assert response_dict['local_versions'] == ['1.1']
async def await_correct_response(target_values: dict) -> None:
response_dict = {}
for _ in range(20):
await asyncio.sleep(1)
response = requests.get(f'http://localhost:{GLOBALS.detector_port}/model_version', timeout=30)
if not response.status_code == 200:
continue
response_dict = json.loads(response.content)
for key, target_value in target_values.items():
if key == 'local_versions':
target_value = set(target_value)

response_value = response_dict.get(key, None)
if response_value != target_value:
break
return
raise Exception(f'Did not receive correct response: {response_dict} != {target_values}')

await await_correct_response({'version_control': 'follow_loop',
'current_version': '1.1',
'target_version': '1.1',
'loop_version': '1.1',
'local_versions': ['1.1']})

response = requests.put(f'http://localhost:{GLOBALS.detector_port}/model_version', data='1.0', timeout=30)
assert response.status_code == 200, response.content
response = requests.get(f'http://localhost:{GLOBALS.detector_port}/model_version', timeout=30)
assert response.status_code == 200, response.content
response_dict = json.loads(response.content)
assert response_dict['version_control'] == 'specific_version'
assert response_dict['current_version'] == '1.1'
assert response_dict['target_version'] == '1.0'
assert response_dict['loop_version'] == '1.1'
assert response_dict['local_versions'] == ['1.1']

await asyncio.sleep(11)
response = requests.get(f'http://localhost:{GLOBALS.detector_port}/model_version', timeout=30)
assert response.status_code == 200, response.content
response_dict = json.loads(response.content)
assert response_dict['version_control'] == 'specific_version'
assert response_dict['current_version'] == '1.0'
assert response_dict['target_version'] == '1.0'
assert response_dict['loop_version'] == '1.1'
assert set(response_dict['local_versions']) == set(['1.1', '1.0'])
await await_correct_response({'version_control': 'specific_version',
'current_version': '1.1',
'target_version': '1.0',
'loop_version': '1.1'})

await await_correct_response({'version_control': 'specific_version',
'current_version': '1.0',
'target_version': '1.0',
'loop_version': '1.1',
'local_versions': ['1.1', '1.0']})

response = requests.put(f'http://localhost:{GLOBALS.detector_port}/model_version', data='pause', timeout=30)
assert response.status_code == 200, response.content
await asyncio.sleep(11)
response = requests.get(f'http://localhost:{GLOBALS.detector_port}/model_version', timeout=30)
assert response.status_code == 200, response.content
response_dict = json.loads(response.content)
assert response_dict['version_control'] == 'pause'
assert response_dict['current_version'] == '1.0'
assert response_dict['target_version'] == '1.0'
assert response_dict['loop_version'] == '1.1'
assert set(response_dict['local_versions']) == set(['1.1', '1.0'])

await await_correct_response({'version_control': 'pause',
'current_version': '1.0',
'target_version': '1.0',
'loop_version': '1.1',
'local_versions': ['1.1', '1.0']})

response = requests.put(f'http://localhost:{GLOBALS.detector_port}/model_version', data='follow_loop', timeout=30)
await asyncio.sleep(11)
response = requests.get(f'http://localhost:{GLOBALS.detector_port}/model_version', timeout=30)
assert response.status_code == 200, response.content
response_dict = json.loads(response.content)
assert response_dict['version_control'] == 'follow_loop'
assert response_dict['current_version'] == '1.1'
assert response_dict['target_version'] == '1.1'
assert response_dict['loop_version'] == '1.1'
assert set(response_dict['local_versions']) == set(['1.1', '1.0'])
await await_correct_response({'version_control': 'follow_loop',
'current_version': '1.1',
'target_version': '1.1',
'loop_version': '1.1',
'local_versions': ['1.1', '1.0']})


async def test_rest_outbox_mode(test_detector_node: DetectorNode):
Expand Down
2 changes: 1 addition & 1 deletion learning_loop_node/tests/detector/test_detector_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def mock_save(*args, **kwargs):

expected_save_args = {
'image': raw_image,
'detections': detector_node.detector_logic.detections, # type: ignore
'detections': detector_node.detector_logic.image_metadata, # type: ignore
'tags': ['test_tag'],
'source': 'test_source',
'creation_date': '2024-01-01T00:00:00',
Expand Down
8 changes: 4 additions & 4 deletions learning_loop_node/tests/general/test_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from ...data_classes import Context
from ...data_exchanger import DataExchanger
from ...globals import GLOBALS
from ...helpers.misc import delete_corrupt_images
from ...helpers.misc import create_image_folder, create_project_folder, create_training_folder, delete_corrupt_images
from .. import test_helper

# Used by all Nodes
Expand Down Expand Up @@ -77,8 +77,8 @@ async def test_removal_of_corrupted_images(data_exchanger: DataExchanger):


def create_needed_folders(training_uuid: str = 'some_uuid'): # pylint: disable=unused-argument
project_folder = test_helper.create_project_folder(
project_folder = create_project_folder(
Context(organization='zauberzeug', project='pytest_nodelib_general'))
image_folder = test_helper.create_image_folder(project_folder)
training_folder = test_helper.create_training_folder(project_folder, training_uuid)
image_folder = create_image_folder(project_folder)
training_folder = create_training_folder(project_folder, training_uuid)
return project_folder, image_folder, training_folder
5 changes: 2 additions & 3 deletions learning_loop_node/tests/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
from glob import glob
from typing import Callable

from ..data_classes import (BoxDetection, ClassificationDetection, Context, ImageMetadata, Point, PointDetection,
from ..data_classes import (BoxDetection, ClassificationDetection, Detections, Point, PointDetection,
SegmentationDetection, Shape)
from ..helpers.misc import create_image_folder, create_project_folder, create_training_folder
from ..loop_communication import LoopCommunicator


Expand Down Expand Up @@ -64,7 +63,7 @@ def _update_attribute_dict(obj: dict, **kwargs) -> None:


def get_dummy_detections():
return ImageMetadata(
return Detections(
box_detections=[
BoxDetection(category_name='some_category_name', x=1, y=2, height=3, width=4,
model_name='some_model', confidence=.42, category_id='some_id')],
Expand Down
6 changes: 3 additions & 3 deletions learning_loop_node/tests/trainer/testing_trainer_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import time
from typing import Dict, List, Optional

from ...data_classes import Context, ImageMetadata, ModelInformation, PretrainedModel, TrainingStateData
from ...data_classes import Context, Detections, ModelInformation, PretrainedModel, TrainingStateData
from ...trainer.trainer_logic import TrainerLogic


Expand Down Expand Up @@ -83,8 +83,8 @@ def _can_resume(self) -> bool:
async def _resume(self) -> None:
return await self._start_training_from_base_model()

async def _detect(self, model_information: ModelInformation, images: List[str], model_folder: str) -> List[ImageMetadata]:
detections: List[ImageMetadata] = []
async def _detect(self, model_information: ModelInformation, images: List[str], model_folder: str) -> List[Detections]:
detections: List[Detections] = []
return detections

async def _clear_training_data(self, training_folder: str) -> None:
Expand Down
Loading

0 comments on commit e823122

Please sign in to comment.