Adding multiple processes to task_proc - TRON-2192 #219

Draft
wants to merge 4 commits into base: master
74 changes: 50 additions & 24 deletions task_processing/plugins/kubernetes/kubernetes_pod_executor.py
@@ -1,7 +1,8 @@
import logging
import multiprocessing
import queue
import threading
import time
from multiprocessing import JoinableQueue
from queue import Queue
from typing import Collection
from typing import Optional
@@ -51,8 +52,8 @@

logger = logging.getLogger(__name__)

POD_WATCH_THREAD_JOIN_TIMEOUT_S = 1.0
POD_EVENT_THREAD_JOIN_TIMEOUT_S = 1.0
POD_WATCH_PROCESS_JOIN_TIMEOUT_S = 1.0
POD_EVENT_PROCESS_JOIN_TIMEOUT_S = 1.0
QUEUE_GET_TIMEOUT_S = 0.5
SUPPORTED_POD_MODIFIED_EVENT_PHASES = {
"Failed",
@@ -100,7 +101,7 @@ def __init__(
self.stopping = False
self.task_metadata: PMap[str, KubernetesTaskMetadata] = pmap()

self.task_metadata_lock = threading.RLock()
self.task_metadata_lock = multiprocessing.RLock()
if task_configs:
for task_config in task_configs:
self._initialize_existing_task(task_config)
@@ -110,33 +111,33 @@ def __init__(
# and we've opted to not do that processing in the Pod event watcher thread so as to keep
# that logic for the threads that operate on them as simple as possible and to make it
# possible to cleanly shutdown both of these.
self.pending_events: "Queue[PodEvent]" = Queue()
self.event_queue: "Queue[Event]" = Queue()
self.pending_events: "JoinableQueue[PodEvent]" = JoinableQueue()
self.event_queue: "JoinableQueue[Event]" = JoinableQueue()

# TODO(TASKPROC-243): keep track of resourceVersion so that we can continue event processing
# from where we left off on restarts
self.pod_event_watch_threads = []
self.pod_event_watch_processes = []
self.watches = []
for kube_client in [self.kube_client] + self.watcher_kube_clients:
watch = kube_watch.Watch()
pod_event_watch_thread = threading.Thread(
pod_event_watch_process = multiprocessing.Process(
target=self._pod_event_watch_loop,
args=(kube_client, watch),
# ideally this wouldn't be a daemon thread, but a watch.Watch() only checks
# ideally this wouldn't be a daemon process, but a watch.Watch() only checks
# if it should stop after receiving an event - and it's possible that we
# have periods with no events so instead we'll attempt to stop the watch
# and then join() with a small timeout to make sure that, if we shutdown
# with the thread alive, we did not drop any events
# with the process alive, we did not drop any events
daemon=True,
)
pod_event_watch_thread.start()
self.pod_event_watch_threads.append(pod_event_watch_thread)
pod_event_watch_process.start()
self.pod_event_watch_processes.append(pod_event_watch_process)
self.watches.append(watch)

self.pending_event_processing_thread = threading.Thread(
self.pending_event_processing_process = multiprocessing.Process(
target=self._pending_event_processing_loop,
)
self.pending_event_processing_thread.start()
self.pending_event_processing_process.start()

def _initialize_existing_task(self, task_config: KubernetesTaskConfig) -> None:
"""Generates task_metadata in UNKNOWN state for an existing KubernetesTaskConfig.
@@ -427,9 +428,18 @@ def _pending_event_processing_loop(self) -> None:
"""
logger.debug("Starting Pod event processing.")
event = None
while not self.stopping or not self.pending_events.empty():
while True:
try:
event = self.pending_events.get(timeout=QUEUE_GET_TIMEOUT_S)
if event["type"] == "STOP":
logger.debug("Received a STOP event - stopping processing.")
try:
self.pending_events.task_done()
except ValueError:
logger.error(
"task_done() called on pending events queue too many times!"
)
break
self._process_pod_event(event)
except queue.Empty:
logger.debug(
@@ -699,33 +709,49 @@ def kill(self, task_id: str) -> bool:
return terminated

def stop(self) -> None:
logger.debug("Preparing to stop all KubernetesPodExecutor threads.")
logger.debug("Preparing to stop all KubernetesPodExecutor processes.")
self.stopping = True

logger.debug("Signaling Pod event Watch to stop streaming events...")
# make sure that we've stopped watching for events before calling join() - otherwise,
# join() will block until we hit the configured timeout (or forever with no timeout).
for watch in self.watches:
watch.stop()

# Add a STOP event to the queue below after stopping the watch to ensure
# no events will be added after the STOP event
stop_event = PodEvent(type="STOP", object=None, raw_object={})
self.pending_events.put(stop_event)

# timeout arbitrarily chosen - we mostly just want to make sure that we have a small
# grace period to flush the current event to the pending_events queue as well as
# any other clean-up - it's possible that after this join() the thread is still alive
# any other clean-up - it's possible that after this join() the process is still alive
# but in that case we can be reasonably sure that we're not dropping any data.
for pod_event_watch_thread in self.pod_event_watch_threads:
pod_event_watch_thread.join(timeout=POD_WATCH_THREAD_JOIN_TIMEOUT_S)
for pod_event_watch_process in self.pod_event_watch_processes:
pod_event_watch_process.join(timeout=POD_WATCH_PROCESS_JOIN_TIMEOUT_S)

logger.debug("Waiting for all pending PodEvents to be processed...")
# once we've stopped updating the pending events queue, we then wait until we're done
# processing any events we've received - this will wait until task_done() has been
# called for every item placed in this queue
# since we stopped the watch above, we don't expect any more events to be added to the queue;
# this ensures we don't get stuck waiting on the STOP event if it wasn't consumed by _pending_event_processing_loop
if (
self.pending_events.qsize() == 1
and self.pending_events.get(timeout=QUEUE_GET_TIMEOUT_S)["type"] == "STOP"
):
try:
self.pending_events.task_done()
except ValueError:
logger.error(
"task_done() called on pending events queue too many times!"
)
self.pending_events.join()
logger.debug("All pending PodEvents have been processed.")
# and then give ourselves time to do any post-stop cleanup
self.pending_event_processing_thread.join(
timeout=POD_EVENT_THREAD_JOIN_TIMEOUT_S
self.pending_event_processing_process.join(
timeout=POD_EVENT_PROCESS_JOIN_TIMEOUT_S
)

logger.debug("Done stopping KubernetesPodExecutor!")

def get_event_queue(self) -> "Queue[Event]":
def get_event_queue(self) -> "JoinableQueue[Event]":
return self.event_queue
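
Taken together, these executor changes replace the threading-based Queue with a multiprocessing.JoinableQueue and rely on a STOP sentinel plus task_done()/join() for shutdown. Below is a minimal, self-contained sketch of that pattern using only the standard library; the names (_worker, pending) are illustrative and are not taken from task_processing:

import multiprocessing
import queue

QUEUE_GET_TIMEOUT_S = 0.5

def _worker(pending):
    # consume events until the STOP sentinel arrives; task_done() must be
    # called exactly once per successful get() so that pending.join() can return
    while True:
        try:
            event = pending.get(timeout=QUEUE_GET_TIMEOUT_S)
        except queue.Empty:
            continue
        try:
            if event["type"] == "STOP":
                break
            print("processing", event["type"])
        finally:
            pending.task_done()

if __name__ == "__main__":
    pending = multiprocessing.JoinableQueue()
    worker = multiprocessing.Process(target=_worker, args=(pending,), daemon=True)
    worker.start()

    pending.put({"type": "ADDED"})
    # enqueue the sentinel only after all producers have stopped, mirroring how
    # stop() enqueues it after the watches have been stopped
    pending.put({"type": "STOP"})

    pending.join()            # returns once task_done() was called for every put()
    worker.join(timeout=1.0)  # small grace period, as in stop()

The qsize() check in stop() handles the case where the sentinel is still sitting unconsumed in the queue; in this simplified sketch the worker always consumes it, so a plain join() is enough.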
3 changes: 2 additions & 1 deletion task_processing/plugins/kubernetes/types.py
@@ -81,7 +81,8 @@ class NodeAffinity(TypedDict):


class PodEvent(TypedDict):
# there are only 3 possible types for Pod events: ADDED, DELETED, MODIFIED
# there are only 4 possible types for Pod events: ADDED, DELETED, MODIFIED, or STOP
# STOP is a custom type that we use to signal shutdown to the KubernetesPodExecutor processes
# XXX: this should be typed as Literal["ADDED", "DELETED", "MODIFIED"] once we drop support
# for older Python versions
type: str
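
As an aside, the XXX comment above sketches how this field could eventually be typed; with the new sentinel, the narrowed type would need to include STOP as well. A rough illustration (the object/raw_object annotations below are approximations inferred from how PodEvent is constructed in this diff, not the repository's actual definitions):

from typing import Literal, Optional, TypedDict

from kubernetes.client import V1Pod

class PodEvent(TypedDict):
    # ADDED / DELETED / MODIFIED come from the Kubernetes watch stream;
    # STOP is the shutdown sentinel introduced by this change
    type: Literal["ADDED", "DELETED", "MODIFIED", "STOP"]
    object: Optional[V1Pod]  # None for STOP events
    raw_object: dict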
7 changes: 7 additions & 0 deletions tests/unit/conftest.py
@@ -1,3 +1,4 @@
import multiprocessing
import threading

import mock
@@ -14,3 +15,9 @@ def mock_sleep():
def mock_Thread():
with mock.patch.object(threading, "Thread") as mock_Thread:
yield mock_Thread


@pytest.fixture
def mock_Process():
with mock.patch.object(multiprocessing, "Process") as mock_Process:
yield mock_Process
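
The new fixture mirrors the existing mock_Thread one: any test that declares mock_Process as an argument gets multiprocessing.Process patched for its duration, so nothing actually forks. A rough usage sketch (the test below is illustrative and not part of this diff):

import multiprocessing

def test_uses_patched_process(mock_Process):
    # inside this test multiprocessing.Process is a MagicMock, so start()
    # records the call instead of spawning a child process
    proc = multiprocessing.Process(target=lambda: None)
    proc.start()

    mock_Process.assert_called_once()
    assert mock_Process.return_value.start.called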
80 changes: 50 additions & 30 deletions tests/unit/plugins/kubernetes/kubernetes_pod_executor_test.py
@@ -12,6 +12,7 @@
from kubernetes.client import V1Pod
from kubernetes.client import V1PodSecurityContext
from kubernetes.client import V1PodSpec
from kubernetes.client import V1PodStatus
from kubernetes.client import V1ProjectedVolumeSource
from kubernetes.client import V1ResourceRequirements
from kubernetes.client import V1SecurityContext
@@ -38,7 +39,7 @@


@pytest.fixture
def k8s_executor(mock_Thread):
def k8s_executor(mock_Process):
with mock.patch(
"task_processing.plugins.kubernetes.kube_client.kube_config.load_kube_config",
autospec=True,
@@ -53,7 +54,7 @@ def k8s_executor(mock_Thread):


@pytest.fixture
def k8s_executor_with_watcher_clusters(mock_Thread):
def k8s_executor_with_watcher_clusters(mock_Process):
with mock.patch(
"task_processing.plugins.kubernetes.kube_client.kube_config.load_kube_config",
autospec=True,
@@ -87,7 +88,7 @@ def mock_task_configs():


@pytest.fixture
def k8s_executor_with_tasks(mock_Thread, mock_task_configs):
def k8s_executor_with_tasks(mock_Process, mock_task_configs):
with mock.patch(
"task_processing.plugins.kubernetes.kube_client.kube_config.load_kube_config",
autospec=True,
@@ -105,13 +106,13 @@ def k8s_executor_with_tasks(mock_Thread, mock_task_configs):


def test_init_watch_setup(k8s_executor):
assert len(k8s_executor.watches) == len(k8s_executor.pod_event_watch_threads) == 1
assert len(k8s_executor.watches) == len(k8s_executor.pod_event_watch_processes) == 1


def test_init_watch_setup_multicluster(k8s_executor_with_watcher_clusters):
assert (
len(k8s_executor_with_watcher_clusters.watches)
== len(k8s_executor_with_watcher_clusters.pod_event_watch_threads)
== len(k8s_executor_with_watcher_clusters.pod_event_watch_processes)
== 2
)

@@ -697,15 +698,18 @@ def test_process_event_enqueues_task_processing_events_pending_to_running(k8s_ex
mock_pod.metadata.name = "test.1234"
mock_pod.status.phase = "Running"
mock_pod.spec.node_name = "node-1-2-3-4"
task_config = KubernetesTaskConfig(
image="test", command="test", uuid="uuid", name="pod--name"
)
mock_event = PodEvent(
type="MODIFIED",
object=mock_pod,
raw_object=mock.Mock(),
raw_object={},
)
k8s_executor.task_metadata = pmap(
{
mock_pod.metadata.name: KubernetesTaskMetadata(
task_config=mock.Mock(spec=KubernetesTaskConfig),
task_config=task_config,
task_state=KubernetesTaskState.TASK_PENDING,
task_state_history=v(),
)
@@ -736,15 +740,18 @@ def test_process_event_enqueues_task_processing_events_running_to_terminal(
mock_pod.metadata.name = "test.1234"
mock_pod.status.phase = phase
mock_pod.spec.node_name = "node-1-2-3-4"
task_config = KubernetesTaskConfig(
image="test", command="test", uuid="uuid", name="pod--name"
)
mock_event = PodEvent(
type="MODIFIED",
object=mock_pod,
raw_object=mock.Mock(),
raw_object={},
)
k8s_executor.task_metadata = pmap(
{
mock_pod.metadata.name: KubernetesTaskMetadata(
task_config=mock.Mock(spec=KubernetesTaskConfig),
task_config=task_config,
task_state=KubernetesTaskState.TASK_RUNNING,
task_state_history=v(),
)
@@ -779,7 +786,7 @@ def test_process_event_enqueues_task_processing_events_no_state_transition(
mock_event = PodEvent(
type="MODIFIED",
object=mock_pod,
raw_object=mock.Mock(),
raw_object={},
)
k8s_executor.task_metadata = pmap(
{
@@ -807,15 +814,28 @@ def test_process_event_enqueues_task_processing_events_no_state_transition(
def test_pending_event_processing_loop_processes_remaining_events_after_stop(
k8s_executor,
):
# Use a real V1Pod object here instead of mock.Mock(): events placed on a
# multiprocessing queue must be picklable, and mock.Mock() is not
test_pod = V1Pod(
metadata=V1ObjectMeta(
name="test-pod",
namespace="task_processing_tests",
)
)
k8s_executor.pending_events.put(
PodEvent(
type="ADDED",
object=mock.Mock(),
raw_object=mock.Mock(),
object=test_pod,
raw_object={},
)
)
k8s_executor.pending_events.put(
PodEvent(
type="STOP",
object=None,
raw_object={},
)
)
k8s_executor.stopping = True

with mock.patch.object(
k8s_executor,
"_process_pod_event",
@@ -835,15 +855,18 @@ def test_process_event_enqueues_task_processing_events_deleted(
mock_pod.status.phase = "Running"
mock_pod.status.host_ip = "1.2.3.4"
mock_pod.spec.node_name = "kubenode"
task_config = KubernetesTaskConfig(
image="test", command="test", uuid="uuid", name="pod--name"
)
mock_event = PodEvent(
type="DELETED",
object=mock_pod,
raw_object=mock.Mock(),
raw_object={},
)
k8s_executor.task_metadata = pmap(
{
mock_pod.metadata.name: KubernetesTaskMetadata(
task_config=mock.Mock(spec=KubernetesTaskConfig),
task_config=task_config,
task_state=KubernetesTaskState.TASK_RUNNING,
task_state_history=v(),
)
Expand All @@ -870,14 +893,13 @@ def test_initial_task_metadata(k8s_executor_with_tasks):
def test_reconcile_missing_pod(
k8s_executor,
):
task_config = mock.Mock(spec=KubernetesTaskConfig)
task_config.pod_name = "pod--name.uuid"
task_config.name = "job-name"

task_config = KubernetesTaskConfig(
image="test", command="test", uuid="uuid", name="pod--name"
)
k8s_executor.task_metadata = pmap(
{
task_config.pod_name: KubernetesTaskMetadata(
task_config=mock.Mock(spec=KubernetesTaskConfig),
task_config=task_config,
task_state=KubernetesTaskState.TASK_UNKNOWN,
task_state_history=v(),
)
@@ -899,14 +921,13 @@ def test_reconcile_multicluster(
def test_reconcile_multicluster(
k8s_executor_with_watcher_clusters,
):
task_config = mock.Mock(spec=KubernetesTaskConfig)
task_config.pod_name = "pod--name.uuid"
task_config.name = "job-name"

task_config = KubernetesTaskConfig(
image="test", command="test", uuid="uuid", name="pod--name"
)
k8s_executor_with_watcher_clusters.task_metadata = pmap(
{
task_config.pod_name: KubernetesTaskMetadata(
task_config=mock.Mock(spec=KubernetesTaskConfig),
task_config=task_config,
task_state=KubernetesTaskState.TASK_UNKNOWN,
task_state_history=v(),
)
@@ -968,10 +989,9 @@ def test_reconcile_existing_pods(k8s_executor, mock_task_configs):
def test_reconcile_api_error(
k8s_executor,
):
task_config = mock.Mock(spec=KubernetesTaskConfig)
task_config.pod_name = "pod--name.uuid"
task_config.name = "job-name"

task_config = KubernetesTaskConfig(
image="test", command="test", uuid="uuid", name="pod--name"
)
with mock.patch.object(
k8s_executor, "kube_client", autospec=True
) as mock_kube_client:
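
The recurring swap from mock.Mock() to real KubernetesTaskConfig and V1Pod objects in these tests follows from the same constraint noted earlier: anything placed on a multiprocessing queue is pickled, and Mock objects are not picklable. A quick illustration of the difference, assuming only that the kubernetes client package is installed:

import pickle
from unittest import mock

from kubernetes.client import V1ObjectMeta, V1Pod

# real API-model objects are plain data and pickle cleanly, so they can be
# handed between the executor's processes via a JoinableQueue
pod = V1Pod(metadata=V1ObjectMeta(name="test-pod", namespace="task_processing_tests"))
pickle.dumps(pod)

# Mock objects cannot be pickled, which is why these tests now build real
# objects for anything that may end up crossing a process boundary
try:
    pickle.dumps(mock.Mock())
except Exception as exc:
    print(f"cannot pickle Mock: {exc!r}")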