import pathlib
import threading
import time
import warnings
from contextlib import contextmanager
from typing import List, Optional, Tuple
import av
import numpy as np
# from line_profiler import profile
from numpy.typing import NDArray
def ts_to_index(ts: float, time: NDArray) -> int:
"""
Return the index of the frame whose experimental time is just before (or equal to) `ts`.
Parameters
----------
ts : float
Experimental timestamp to match.
time : NDArray
Array of experimental timestamps, assumed sorted in ascending order,
with one entry per frame.
Returns
-------
idx : int
Index of the frame with time <= `ts`. Clipped to [0, len(time) - 1].
Notes
-----
- If `ts` is smaller than all values in `time`, returns 0.
- If `ts` is greater than all values in `time`, returns `len(time) - 1`.
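Examples
--------
>>> t = np.array([0.0, 0.1, 0.2])
>>> int(ts_to_index(0.15, t))
1
>>> int(ts_to_index(-1.0, t)), int(ts_to_index(5.0, t))
(0, 2)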
"""
idx = np.searchsorted(time, ts, side="right") - 1
return np.clip(idx, 0, len(time) - 1)
class VideoHandler:
"""Class for getting video frames."""
_get_from_index = False  # when True, ``get`` interprets its argument as a frame index
def __init__(
self,
video_path: str | pathlib.Path,
stream_index: int = 0,
time: Optional[NDArray] = None,
return_frame_array: bool = True,
) -> None:
self.video_path = pathlib.Path(video_path)
self.container = av.open(video_path)
self.stream = self.container.streams.video[stream_index]
self.stream_index = stream_index
self.return_frame_array = return_frame_array
self._running = True
# default to linspace
# TODO : what if number of frames is 0.
if time is None:
self._time_provided = False
n_frames = self.stream.frames
frame_duration = 1 / float(self.stream.average_rate)
self.time = np.linspace(0, frame_duration * n_frames - frame_duration, n_frames)
else:
# TODO : check that the number of time points matches the number of frames
self._time_provided = True
self.time = np.asarray(time)
# initialize the index of the last decoded frame: when other signals
# (e.g. LFP) are sampled much more densely, consecutive queries often
# map to the same frame, so cache its index
self.last_loaded_idx = None
# initialize current frame
self.current_frame: Optional[av.VideoFrame] = None
if self.video_path.suffix == ".mkv":
# mkv time is rounded to 3 digits, at least in the example video
# generated by tests/generate_numbered_video.py
self.round_fn = lambda x: np.round(x, 3)
else:
self.round_fn = lambda x: x
# These will be initialized in the thread once n_frames is known
self.all_pts = None
self.all_times = None
self.key_mask = None
self._i = 0 # write position
self._lock = threading.Lock()
if self.stream.frames and self.stream.frames > 0:
self._index_thread = threading.Thread(target=self._build_index_fixed_size, daemon=True)
else:
self._index_thread = threading.Thread(target=self._build_index_dynamic, daemon=True)
self._index_ready = threading.Event()
self._index_thread.start()
self._keypoint_pts = []
self._pts_keypoint_ready = threading.Event()
self._keypoint_thread = threading.Thread(target=self._extract_keypoints_pts, daemon=True)
self._keypoint_thread.start()
def extract_keyframe_times_and_points(
self, video_path: str | pathlib.Path, stream_index: int = 0, first_only: bool = False
) -> Tuple[NDArray, NDArray] | None:
"""
Extract the indices and timestamps of keyframes from a video file.
This method decodes the video while skipping non-keyframes, and records:
- The presentation time (in seconds) of each keyframe.
- The presentation timestamp (``pts``) of each keyframe.
It is typically intended to run in a background thread during
initialization of a ``VideoHandler``, and supports optimized seeking:
- When the requested frame (based on experimental time) is before the
current playback position, seeking backward is necessary.
- When the requested frame is beyond the next known keyframe, seeking
forward to the closest keyframe is more efficient than decoding all
intermediate frames.
Parameters
----------
video_path : str or pathlib.Path
The path to the video file.
stream_index : int
The index of the video stream.
first_only : bool
If True, return the first keyframe only. Used at initialization.
Returns
-------
keyframe_pts : NDArray
The presentation timestamp (``pts``) of each keyframe.
keyframe_timestamps : NDArray[float]
The presentation time of each keyframe, in seconds.
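Examples
--------
A minimal sketch; ``"video.mp4"`` is a placeholder path.
>>> handler = VideoHandler("video.mp4")  # doctest: +SKIP
>>> pts, times = handler.extract_keyframe_times_and_points("video.mp4")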
"""
keyframe_timestamp = []
keyframe_pts = []
with av.open(video_path) as container:
stream = container.streams.video[stream_index]
stream.codec_context.skip_frame = "NONKEY"
for frame in container.decode(stream):
if not self._running:
return None
keyframe_timestamp.append(frame.time)
keyframe_pts.append(frame.pts)
if first_only:
break
return np.asarray(keyframe_pts), np.asarray(keyframe_timestamp, dtype=float)
@contextmanager
def _set_get_from_index(self, value):
"""Context manager for setting the shallow copy flag in a thread safe way."""
old_value = self.__class__._get_from_index
self.__class__._get_from_index = value
try:
yield
finally:
self.__class__._get_from_index = old_value
def _extract_keypoints_pts(self):
try:
with av.open(self.video_path) as container:
stream = container.streams.video[self.stream_index]
for packet in container.demux(stream):
if not self._running:
return
if packet.is_keyframe:
with self._lock:
self._keypoint_pts.append(packet.pts)
except Exception as e:
# do not block gui
print("Keypoint thread error:", e)
finally:
self._pts_keypoint_ready.set()
def _build_index_fixed_size(self):
try:
with av.open(self.video_path) as container:
stream = container.streams.video[self.stream_index]
n_frames = stream.frames
if not n_frames or n_frames <= 0:
raise ValueError("Cannot determine total number of frames in stream.")
self.all_pts = np.empty(n_frames, dtype=np.int64)
self._i = 0 # Number of valid entries
for packet in container.demux(stream):
if not self._running:
return
for frame in packet.decode():
if self._i >= n_frames:
break
with self._lock:
self.all_pts[self._i] = frame.pts
self._i += 1
except Exception as e:
print("Index thread error:", e)
finally:
self._index_ready.set()
def _build_index_dynamic(self):
try:
with av.open(self.video_path) as container:
if not self._running:
return
stream = container.streams.video[self.stream_index]
pts_list = []
current_index = 0
flush_every = 10  # flush decoded pts to the shared index every N frames
for packet in container.demux(stream):
for frame in packet.decode():
if frame.pts is not None:
pts_list.append(frame.pts)
if current_index % flush_every == 1:
with self._lock:
# copy so readers under the lock see a stable snapshot
self.all_pts = list(pts_list)
self._i = current_index
current_index += 1
except Exception as e:
print("Index thread error:", e)
finally:
self._index_ready.set()
def _need_seek_call(self, current_frame_pts, target_frame_pts):
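"""Return True when a ``seek`` call is needed to reach ``target_frame_pts``.
Seeking is required when the keyframe table is still incomplete, when
moving backwards, or when a keyframe lies between the current position
and the target (seeking to it is cheaper than decoding every
intermediate frame).
"""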
with self._lock:
# return True if the keypoint list is empty or does not yet cover the target
if len(self._keypoint_pts) == 0 or self._keypoint_pts[-1] < target_frame_pts:
return True
# roll back the stream if video is scrolled backwards
if current_frame_pts > target_frame_pts:
return True
# find the closest keypoint pts before a given frame
idx = np.searchsorted(self._keypoint_pts, target_frame_pts, side="right")
closest_keypoint_pts = self._keypoint_pts[max(0, idx - 1)]
# at this point target_frame_pts >= current_frame_pts (otherwise the
# second return above fired), so seek forward only when a keyframe
# lies between the current position and the target
return closest_keypoint_pts > current_frame_pts
def _get_frame_idx(self, pts: int) -> Tuple[int, bool]:
"""
Get the frame index from the presentation time stamp.
Parameters
----------
pts:
The presentation time stamp of the frame.
Returns
-------
idx:
The frame index corresponding to the given pts.
use_time:
If True, the returned index is an estimate, and the caller should
match frames by presentation time in seconds rather than by pts.
"""
# Wait until enough of the index is available
# Estimate pts from the index (using the filled prefix if available)
with self._lock:
done = self._i > 0 and self.all_pts is not None and self.all_pts[self._i - 1] > pts
if done:
# the pts for this timestamp has been filled; search the valid
# prefix only, since entries past ``self._i`` are uninitialized
idx = np.searchsorted(self.all_pts[: self._i], pts, side="right")
use_time = False
else:
# keep going until at least two frames have been decoded by the thread
while True:
with self._lock:
if self._i > 1:
break
time.sleep(0.001)
# use recent history to get the step estimate
with self._lock:
# Linear extrapolation from available pts (use last 10 steps for an estimate)
start, stop = max(self._i - 10, 0), self._i
avg_step = np.mean(np.diff(self.all_pts[start:stop]))
idx = int((pts - self.all_pts[0]) / avg_step)
use_time = True
return idx, use_time
def _get_target_frame_pts(self, idx: int) -> Tuple[int, bool]:
"""
Get the target frame presentation time stamp from frame index.
Parameters
----------
idx:
The frame index.
Returns
-------
target_pts:
The target frame presentation time stamp corresponding to the frame index.
use_time:
If True, the returned pts is an estimate, and the caller should
match frames by presentation time in seconds rather than by pts.
"""
# Wait until enough index is available
# Estimate pts from index (using filled index if available)
with self._lock:
done = self._i > idx
if done:
# the pts for this timestamp has been filled
target_pts = self.all_pts[idx]
use_time = False
else:
# keep going until at least two frames have been decoded by the thread
while True:
with self._lock:
if self._i > 1:
break
time.sleep(0.001)
# use recent history to get the step estimate
with self._lock:
# Linear extrapolation from available pts (use last 10 steps for an estimate)
start, stop = max(self._i - 10, 0), self._i
avg_step = np.mean(np.diff(self.all_pts[start:stop]))
target_pts = int(self.all_pts[0] + avg_step * idx)
use_time = True
return target_pts, use_time
def get_key_frame(self, backward: bool) -> Tuple[av.VideoFrame | NDArray, int]:
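"""Seek to the adjacent keyframe and return it.
Parameters
----------
backward : bool
If True, seek to the keyframe preceding the current position,
otherwise to the one following it.
Returns
-------
frame : av.VideoFrame or NDArray
The decoded keyframe (an RGB array if ``return_frame_array`` is True).
idx : int
The index of the returned keyframe.
"""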
idx = self.last_loaded_idx
if idx is None:
# fallback to safe keypoint
self._pts_keypoint_ready.wait(2.0)
if len(self._keypoint_pts) > 0:
idx = self._get_frame_idx(self._keypoint_pts[0])[0]
else:
idx = 0 # safe fallback
# Get the pts of the last loaded index
target_pts, use_time = self._get_target_frame_pts(idx)
# Seek the next or previous keyframe based on the direction
with self._lock:
if len(self._keypoint_pts) > 1:
delta = max(np.mean(np.diff(self._keypoint_pts[:10])) // 2, 1)
else:
# not enough keyframes decoded yet for a step estimate
delta = 1
try:
self.container.seek(
int(
target_pts + (-delta if backward else delta)
), # if you're on top of a key frame, seek does not move no matter what
backward=backward,
any_frame=False,
stream=self.stream,
)
except av.error.PermissionError:
# seek backward at the end of the file
self.container.seek(
int(target_pts),
backward=True,
any_frame=False,
stream=self.stream,
)
# Decode the next frame, which should be a keyframe
frame = next(
frame
for packet in self.container.demux(self.stream)
if packet is not None
for frame in packet.decode()
)
self.current_frame = frame
# Get the index of the key frame
self.last_loaded_idx = self._get_frame_idx(frame.pts)[0] - 1
# Return both
return (
self.current_frame.to_ndarray(format="rgb24")[::-1] / 255.0
if self.return_frame_array
else self.current_frame,
self.last_loaded_idx,
)
def get(self, ts: float) -> av.VideoFrame | NDArray:
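"""Return the frame whose experimental time is at (or just before) ``ts``.
Parameters
----------
ts : float
Experimental timestamp. When the class-level ``_get_from_index``
flag is set (as done by ``__getitem__``), ``ts`` is instead
interpreted as a frame index.
Returns
-------
frame : av.VideoFrame or NDArray
The matching frame (an RGB array if ``return_frame_array`` is True).
"""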
if not self.__class__._get_from_index:
idx = ts_to_index(ts, self.time)
else:
idx = ts
if idx == self.last_loaded_idx:
return (
self.current_frame.to_ndarray(format="rgb24")[::-1] / 255.0
if self.return_frame_array
else self.current_frame
)
target_pts, use_time = self._get_target_frame_pts(idx)
if not hasattr(self.current_frame, "pts") or self._need_seek_call(
self.current_frame.pts, target_pts
):
self.container.seek(
int(target_pts), backward=True, any_frame=False, stream=self.stream
)
# Decode forward from the keypoint until the frame just before (or equal to) target_pts
last_idx, preceding_frame = self._decode_and_check_frames(use_time, target_pts, idx)
if preceding_frame is not None:
self.last_loaded_idx = idx
self.current_frame = preceding_frame
return (
self.current_frame.to_ndarray(format="rgb24")[::-1] / 255.0
if self.return_frame_array
else self.current_frame
)
def _frame_iterator(self, fall_back_pts: int | None):
"""
Safe frame iterator.
Iterate frames from the current stream position. If an end-of-file
error is raised, seek back to ``fall_back_pts`` and resume iterating
from there.
"""
try:
for packet in self.container.demux(self.stream):
if packet is None:
continue
for frame in packet.decode():
if frame.pts is None:
continue
yield frame
except av.error.EOFError as e:
if fall_back_pts is None:
raise e
self.container.seek(
int(fall_back_pts), backward=True, any_frame=False, stream=self.stream
)
yield from self._frame_iterator(None)
def _decode_and_check_frames(self, use_time: bool, target_pts: int, idx: int):
"""Decode from stream."""
preceding_frame = None
last_idx = self.last_loaded_idx
frame_duration = 1 / float(self.stream.average_rate)
time_threshold = self.round_fn(idx * frame_duration)
for frame in self._frame_iterator(target_pts):
if frame.pts is None:
continue
if (not use_time and frame.pts > target_pts) or (
use_time and frame.time > time_threshold
):
last_idx = idx
current_frame = preceding_frame or frame
return last_idx, current_frame
elif (not use_time and frame.pts == target_pts) or (
use_time and frame.time == time_threshold
):
last_idx = idx
current_frame = frame
return last_idx, current_frame
preceding_frame = frame
return last_idx, preceding_frame
@property
def shape(self):
if (
self._time_provided
): # TODO maybe check what is the actual number of frames decoded and throw a warning
return len(self.time), self.stream.width, self.stream.height
has_frames = hasattr(self.stream, "frames") and self.stream.frames > 0
is_done_unpacking = self._index_ready.is_set()
if not has_frames and not is_done_unpacking:
warnings.warn(
message="Video ``shape``, which corresponds to the number of frames, is being "
"calculated runtime and will be updated.",
stacklevel=2,
)
n_frames_indexed = len(self.all_pts) if self.all_pts is not None else 0
return (
(len(self.time), self.stream.width, self.stream.height)
if has_frames
else (n_frames_indexed, self.stream.width, self.stream.height)
)
@property
def index(self):
if self._time_provided:
return self.time
else:
has_frames = hasattr(self.stream, "frames") and self.stream.frames > 0
is_done_unpacking = self._index_ready.is_set()
if not has_frames and not is_done_unpacking:
warnings.warn(
message="Video ``shape``, which corresponds to the number of frames, is being "
"calculated runtime and will be updated.",
stacklevel=2,
)
return self.time
@property
def t(self):
return self.time
def close(self):
"""Close the video stream."""
self._running = False
if self._index_thread.is_alive():
self._index_thread.join(timeout=1)  # Be conservative, don't block forever
if self._keypoint_thread.is_alive():
self._keypoint_thread.join(timeout=1)
try:
self.container.close()
except Exception:
print("VideoHandler failed to close the video stream.")
finally:
# dropping refs to fully close av.InputContainer
self.container = None
self.stream = None
def _wait_for_index(self, timeout=2.0):
"""Wait up to timeout.
For debugging purposes, or testing, make sure that the
threads are completed.
"""
self._index_ready.wait(timeout)
self._pts_keypoint_ready.wait(timeout)
def get_slice(self, start: float, end: Optional[float] = None) -> slice:
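"""Convert a time interval to a slice of frame indices.
Parameters
----------
start : float
Start timestamp of the interval.
end : float, optional
End timestamp. If omitted, the slice covers a single frame.
Returns
-------
slice
The frame-index slice covering the interval.
"""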
# TODO : check that start and end are sorted
start = ts_to_index(start, self.time)
if end is not None:
end = ts_to_index(end, self.time)
return slice(start, end)
return slice(start, start + 1)
def _append_frame(self, frames, idx, frame):
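"""Store ``frame`` at ``frames[idx]`` in array mode, or append it in frame mode."""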
if self.return_frame_array:
frames[idx] = frame.to_ndarray(format="rgb24")[::-1] / 255.0
else:
frames.append(frame)
def _decode_multiple(
self,
target_pts,
idx_start: int,
idx_end: int,
step: int = 1,
) -> Tuple[int, List[av.VideoFrame] | NDArray, av.VideoFrame]:
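"""Decode the frames selected by ``idx_start:idx_end:step``.
Returns the last requested index, the collected frames (an array when
``return_frame_array`` is True, otherwise a list), and the last decoded
frame, which the caller uses to update the frame cache.
"""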
effective_end = min(idx_end, self.shape[0])
indices = np.arange(idx_start, effective_end, step)
num_frames = len(indices)
frame_duration = 1 / float(self.stream.average_rate)
time_threshold_all = self.round_fn(indices * frame_duration)
if self.return_frame_array:
frames = np.empty(
(num_frames, self.shape[2], self.shape[1], 3),
dtype=np.float32,
)
else:
frames = []
collected = 0
# initialize current frame
if self.current_frame is None:
self.get(0)
preceding_frame = self.current_frame
go_to_next_packet = False
while collected < num_frames:
if not go_to_next_packet:
target_pts, use_time = self._get_target_frame_pts(indices[collected])
# First frame shortcut
if collected == 0 and hasattr(self.current_frame, "pts"):
if self.current_frame.pts == target_pts:
self._append_frame(frames, collected, self.current_frame)
collected = 1
continue
elif self.current_frame.pts > target_pts:
self.current_frame = None
self.container.seek(
int(target_pts),
backward=True,
any_frame=False,
stream=self.stream,
)
go_to_next_packet = True
if not go_to_next_packet and self._need_seek_call(preceding_frame.pts, target_pts):
self.container.seek(
int(target_pts),
backward=True,
any_frame=False,
stream=self.stream,
)
try:
packet = next(self.container.demux(self.stream))
decoded = packet.decode()
while len(decoded) == 0:
# pull the next packet; re-decoding the same one would loop forever
packet = next(self.container.demux(self.stream))
decoded = packet.decode()
except (av.error.EOFError, StopIteration):
# end of the video, stop collecting
break
for frame in decoded:
if frame.pts is None:
continue
time_threshold = time_threshold_all[collected]
found_next = (
(frame.pts > target_pts) if not use_time else (frame.time > time_threshold)
)
found_current = (
(frame.pts == target_pts) if not use_time else (frame.time == time_threshold)
)
if found_next:
self._append_frame(frames, collected, preceding_frame)
collected += 1
go_to_next_packet = False
elif found_current:
self._append_frame(frames, collected, frame)
collected += 1
go_to_next_packet = False
else:
go_to_next_packet = True
preceding_frame = frame
# ``preceding_frame`` tracks the last decoded frame and is always defined,
# even when the loop exits early on end-of-file
return indices[-1], frames, preceding_frame
def __getitem__(self, idx: slice | int):
"""
Get item for video frame.
Gets one or more frames from a video.
Parameters
----------
idx : int or slice
The index or slice selecting frames.
Returns
-------
frames : av.VideoFrame or NDArray or list
A single frame for an integer index, or a stack of frames for a slice.
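Examples
--------
A minimal sketch; ``handler`` is assumed to be an open ``VideoHandler``.
>>> frame = handler[5]  # doctest: +SKIP
>>> clip = handler[0:10:2]  # doctest: +SKIP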
"""
if isinstance(idx, slice):
# Fill in missing slice components
start = idx.start or 0
if start >= self.shape[0]:
if self.return_frame_array:
return np.empty((0, self.shape[2], self.shape[1], 3))
else:
return []
stop = idx.stop if idx.stop is not None else self.shape[0]
step = idx.step if idx.step is not None else 1
# convert negative vals
start = start if start >= 0 else start + self.shape[0]
start = max(0, min(start, self.shape[0]))
stop = stop + self.shape[0] if stop < 0 else stop
stop = max(0, min(stop, self.shape[0]))
# revert slice if negative step
revert = step < 0
step = abs(step)
if (stop - start) // step > 1:
target_pts, use_time = self._get_target_frame_pts(start)
if not hasattr(self.current_frame, "pts") or self._need_seek_call(
self.current_frame.pts, target_pts
):
self.container.seek(
int(target_pts), backward=True, any_frame=False, stream=self.stream
)
frame_idx, frames, last_frame = self._decode_multiple(
target_pts, start, stop, step=step
)
# update current decoded frame
if len(frames):
self.last_loaded_idx = frame_idx
self.current_frame = last_frame
return frames if not revert else frames[::-1]
# Default case: single index
with self._set_get_from_index(True):
# TODO : check borders
idx_start = (idx.start or 0) if isinstance(idx, slice) else idx
idx_start = idx_start if idx_start >= 0 else self.shape[0] + idx_start
frame = self.get(idx_start)
if isinstance(idx, slice):
frame = np.expand_dims(frame, axis=0)
return frame
def __len__(self):
return self.shape[0]
# context-manager protocol:
# ``with VideoHandler(path) as video:`` guarantees the stream is closed
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
self.close()