Newmetric: VMAF #2991

Open: wants to merge 35 commits into base: master

35 commits
96832e1
add new requirements
SkafteNicki Mar 5, 2025
e6f8fe2
initial structure
SkafteNicki Mar 5, 2025
fe9cd7a
docs
SkafteNicki Mar 5, 2025
16b0639
test structure
SkafteNicki Mar 5, 2025
06138f6
add starting point of implementation
SkafteNicki Mar 5, 2025
cf7a8ba
changelog
SkafteNicki Mar 5, 2025
6b4397f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 5, 2025
792f67e
improve implementations
SkafteNicki Mar 5, 2025
ce56ce4
merge
SkafteNicki Mar 5, 2025
1c41c8c
Merge branch 'master' into newmetric/vmaf
SkafteNicki Mar 5, 2025
fb5d7c3
Merge branch 'master' into newmetric/vmaf
Borda Mar 7, 2025
6b4daca
fix init and import
SkafteNicki Mar 31, 2025
461d70d
improve functional implementation
SkafteNicki Mar 31, 2025
647bb9c
improve src
SkafteNicki Mar 31, 2025
ad2a46a
add testing
SkafteNicki Mar 31, 2025
43322c2
improve tests
SkafteNicki Mar 31, 2025
1072c61
tests
SkafteNicki Mar 31, 2025
4c48e08
Merge branch 'master' into newmetric/vmaf
Borda Apr 1, 2025
dcfa027
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 1, 2025
6c11a89
smaller changes
SkafteNicki Apr 9, 2025
6f2a4be
Merge branch 'master' into newmetric/vmaf
Borda Apr 16, 2025
a937f26
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 16, 2025
2bef8a4
Merge branch 'master' into newmetric/vmaf
Borda Apr 16, 2025
69ff57e
Merge branch 'master' into newmetric/vmaf
Borda Apr 25, 2025
1159b0d
Merge branch 'master' into newmetric/vmaf
SkafteNicki Apr 29, 2025
5733e82
fix functional implementation
SkafteNicki Apr 29, 2025
53f2bc5
fix modular implementation
SkafteNicki Apr 29, 2025
c873062
redo functional implementation
SkafteNicki Apr 29, 2025
e5500b3
fix modular implementation
SkafteNicki Apr 29, 2025
3358bd5
fix modular implementation
SkafteNicki Apr 29, 2025
c1305c2
unittests
SkafteNicki Apr 29, 2025
cf5205d
Merge branch 'master' into newmetric/vmaf
SkafteNicki Apr 29, 2025
28b4def
fix typing
SkafteNicki Apr 29, 2025
1bdc2e3
Merge branch 'newmetric/vmaf' of https://github.com/Lightning-AI/torc…
SkafteNicki Apr 29, 2025
048d611
Merge branch 'master' into newmetric/vmaf
SkafteNicki Apr 29, 2025
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Added `VMAF` metric to new video domain ([#2991](https://github.com/Lightning-AI/torchmetrics/pull/2991))


- Added CRPS in regression domain ([#3024](https://github.com/Lightning-AI/torchmetrics/pull/3024))


8 changes: 8 additions & 0 deletions docs/source/index.rst
@@ -247,6 +247,14 @@ Or directly from conda

   text/*

.. toctree::
   :maxdepth: 2
   :name: video
   :caption: Video
   :glob:

   video/*

.. toctree::
   :maxdepth: 2
   :name: wrappers
22 changes: 22 additions & 0 deletions docs/source/video/vmaf.rst
@@ -0,0 +1,22 @@
.. customcarditem::
   :header: Video Multi-Method Assessment Fusion (VMAF)
   :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
   :tags: Video

.. include:: ../links.rst

###########################################
Video Multi-Method Assessment Fusion (VMAF)
###########################################

Module Interface
________________

.. autoclass:: torchmetrics.video.VideoMultiMethodAssessmentFusion
    :noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.video.video_multi_method_assessment_fusion
    :noindex:
1 change: 1 addition & 0 deletions requirements/_devel.txt
@@ -10,6 +10,7 @@
-r image.txt
-r text.txt
-r multimodal.txt
-r video.txt
-r visual.txt

# add extra testing
2 changes: 2 additions & 0 deletions requirements/video.txt
@@ -0,0 +1,2 @@
vmaf_torch @ git+https://github.com/alvitrioliks/VMAF-torch
Review comment (Member): Could we ask them to publish the package on PyPI, or would we need to do it ourselves?

einops
21 changes: 21 additions & 0 deletions src/torchmetrics/functional/video/__init__.py
@@ -0,0 +1,21 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.utilities.imports import _TORCH_VMAF_AVAILABLE

if _TORCH_VMAF_AVAILABLE:
    from torchmetrics.functional.video.vmaf import video_multi_method_assessment_fusion

    __all__ = ["video_multi_method_assessment_fusion"]
else:
    __all__ = []
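
The conditional export above follows the usual optional-dependency pattern: a public name is only exposed when the package behind it can be imported. A minimal standard-library sketch of the same idea follows. `optional_exports` is a hypothetical helper, and `importlib.util.find_spec` stands in for the `RequirementCache` check used in the PR; this is an assumption made so the example runs without extra packages.

```python
from importlib.util import find_spec


def optional_exports(module_name: str, names: list[str]) -> list[str]:
    """Return `names` if `module_name` is importable, else an empty list.

    Hypothetical helper: find_spec is a stand-in for the RequirementCache
    check in the PR, it is not what torchmetrics itself calls.
    """
    return list(names) if find_spec(module_name) is not None else []


# "json" ships with Python, so its export list is kept unchanged.
print(optional_exports("json", ["dumps"]))
# A package that is not installed yields an empty export list.
print(optional_exports("surely_not_an_installed_package_xyz", ["anything"]))
```

With this pattern, `from package import *` and documentation tooling only see the names whose dependencies are actually present.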
135 changes: 135 additions & 0 deletions src/torchmetrics/functional/video/vmaf.py
@@ -0,0 +1,135 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
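from typing import Dict, Union

import torch
from torch import Tensor

from torchmetrics.utilities.imports import _TORCH_VMAF_AVAILABLE

# Import the optional dependencies only when available, so that the RuntimeError raised in
# `video_multi_method_assessment_fusion` is reachable instead of an import-time failure.
# Assumption: einops is installed together with vmaf_torch via requirements/video.txt.
if _TORCH_VMAF_AVAILABLE:
    import vmaf_torch
    from einops import rearrange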


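def calculate_luma(video: Tensor) -> Tensor:
    """Calculate the luma component of an RGB video tensor.

    Expects shape (batch, 3, frames, height, width) with values in [0, 1] and returns a
    (batch, 1, frames, height, width) tensor in [0, 255], using the ITU-R BT.601 luma weights.
    """
    r = video[:, 0, :, :, :]
    g = video[:, 1, :, :, :]
    b = video[:, 2, :, :, :]
    return (0.299 * r + 0.587 * g + 0.114 * b).unsqueeze(1) * 255  # [0, 1] -> [0, 255]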


def video_multi_method_assessment_fusion(
    preds: Tensor,
    target: Tensor,
    features: bool = False,
) -> Union[Tensor, Dict[str, Tensor]]:
    """Calculates Video Multi-Method Assessment Fusion (VMAF) metric.

    VMAF is a full-reference video quality assessment algorithm that combines multiple quality assessment features
    such as detail loss, motion, and contrast using a machine learning model to predict human perception of video
    quality more accurately than traditional metrics like PSNR or SSIM.

    The metric works by:
        1. Converting input videos to luma component (grayscale)
        2. Computing multiple elementary features:
            - Additive Detail Measure (ADM): Evaluates detail preservation at different scales
            - Visual Information Fidelity (VIF): Measures preservation of visual information across frequency bands
            - Motion: Quantifies the amount of motion in the video
        3. Combining these features using a trained SVM model to predict quality

    .. note::
        This implementation requires you to have vmaf-torch installed: https://github.com/alvitrioliks/VMAF-torch.
        Install either by cloning the repository and running `pip install .` or with `pip install torchmetrics[video]`.

    Args:
        preds: Video tensor of shape (batch, channels, frames, height, width). Expected to be in RGB format
            with values in range [0, 1].
        target: Video tensor of shape (batch, channels, frames, height, width). Expected to be in RGB format
            with values in range [0, 1].
        features: If True, all the elementary features (ADM, VIF, motion) are returned along with the VMAF score in
            a dictionary. This corresponds to the output you would get from the VMAF command line tool with the `--csv`
            option enabled. If False, only the VMAF score is returned as a tensor.

    Returns:
        If `features` is False, returns a tensor with shape (batch, frame) of VMAF score for each frame in
        each video. Higher scores indicate better quality, with typical values ranging from 0 to 100.

        If `features` is True, returns a dictionary where each value is a (batch, frame) tensor of the
        corresponding feature. The keys are:
            - 'integer_motion2': Integer motion feature
            - 'integer_motion': Integer motion feature
            - 'integer_adm2': Integer ADM feature
            - 'integer_adm_scale0': Integer ADM feature at scale 0
            - 'integer_adm_scale1': Integer ADM feature at scale 1
            - 'integer_adm_scale2': Integer ADM feature at scale 2
            - 'integer_adm_scale3': Integer ADM feature at scale 3
            - 'integer_vif_scale0': Integer VIF feature at scale 0
            - 'integer_vif_scale1': Integer VIF feature at scale 1
            - 'integer_vif_scale2': Integer VIF feature at scale 2
            - 'integer_vif_scale3': Integer VIF feature at scale 3
            - 'vmaf': VMAF score for each frame in each video

    Example:
        >>> import torch
        >>> from torchmetrics.functional.video import video_multi_method_assessment_fusion
        >>> preds = torch.rand(2, 3, 10, 32, 32)  # 2 videos, 3 channels, 10 frames, 32x32 resolution
        >>> target = torch.rand(2, 3, 10, 32, 32)  # 2 videos, 3 channels, 10 frames, 32x32 resolution
        >>> video_multi_method_assessment_fusion(preds, target)
        tensor([[ 7.0141, 17.4276, 15.1429, 14.9831, 19.3378, 12.5638, 13.9680, 13.4165, 17.9314, 16.0849],
                [ 7.2760, 11.1951, 15.3990, 13.5877, 15.1370, 18.4508, 17.5596, 18.6859, 12.9309, 15.1975]])
        >>> vmaf_dict = video_multi_method_assessment_fusion(preds, target, features=True)
        >>> # show a couple of features, more features are available
        >>> vmaf_dict['vmaf']
        tensor([[ 7.0141, 17.4276, 15.1429, 14.9831, 19.3378, 12.5638, 13.9680, 13.4165, 17.9314, 16.0849],
                [ 7.2760, 11.1951, 15.3990, 13.5877, 15.1370, 18.4508, 17.5596, 18.6859, 12.9309, 15.1975]])
        >>> vmaf_dict['integer_adm2']
        tensor([[0.3980, 0.4449, 0.3650, 0.4070, 0.4504, 0.4817, 0.4489, 0.3870, 0.4140, 0.4029],
                [0.4134, 0.4216, 0.4050, 0.4138, 0.4209, 0.4378, 0.4333, 0.4358, 0.3846, 0.4168]])
        >>> vmaf_dict['integer_vif_scale0']
        tensor([[0.0004, 0.0002, 0.0004, 0.0007, 0.0010, 0.0007, 0.0011, 0.0001, 0.0006, 0.0003],
                [0.0004, 0.0002, 0.0010, 0.0005, 0.0012, 0.0010, 0.0004, 0.0002, 0.0010, 0.0009]])

    """
    if not _TORCH_VMAF_AVAILABLE:
        raise RuntimeError("vmaf-torch is not installed. Please install with `pip install torchmetrics[video]`.")
    b = preds.shape[0]
    orig_dtype, device = preds.dtype, preds.device
    preds_luma = calculate_luma(preds)
    target_luma = calculate_luma(target)

    vmaf = vmaf_torch.VMAF().to(device)

    # the model is run on each video separately, with every video rearranged from
    # (channels, frames, height, width) to (frames, channels, height, width)
    if not features:
        scores = [
            vmaf.compute_vmaf_score(
                rearrange(target_luma[video], "c f h w -> f c h w"), rearrange(preds_luma[video], "c f h w -> f c h w")
            )
            for video in range(b)
        ]
        # concatenating the per-video scores along dim 1 and transposing yields (batch, frames)
        return torch.cat(scores, dim=1).t().to(orig_dtype)
    import pandas as pd  # pandas is installed as a dependency of vmaf-torch

    scores_and_features = [
        vmaf.table(
            rearrange(target_luma[video], "c f h w -> f c h w"), rearrange(preds_luma[video], "c f h w -> f c h w")
        )
        for video in range(b)
    ]
    dfs = [scores_and_features[video].apply(pd.to_numeric, errors="coerce") for video in range(b)]
    # build the feature tensors on the input device so both return paths agree
    result = [
        {
            col: torch.tensor(dfs[video][col].values, dtype=orig_dtype, device=device)
            for col in dfs[video].columns
            if col != "Frame"
        }
        for video in range(b)
    ]
    # stack the per-video tensors so that every feature has shape (batch, frames)
    return {col: torch.stack([result[video][col] for video in range(b)]) for col in result[0]}
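
The luma step in `calculate_luma` is a weighted sum of the R, G and B channels, scaled from [0, 1] to [0, 255]. The sketch below reproduces that arithmetic for single pixels in plain Python so the weights can be checked by hand. `luma_255` is a hypothetical helper used only for illustration; it is not part of the PR.

```python
def luma_255(r: float, g: float, b: float) -> float:
    """Luma of one RGB pixel with channels in [0, 1], scaled to [0, 255].

    Uses the same weights as calculate_luma in the diff above.
    """
    return (0.299 * r + 0.587 * g + 0.114 * b) * 255


# The three weights sum to 1, so white maps to roughly 255 and black to 0.
white = luma_255(1.0, 1.0, 1.0)
black = luma_255(0.0, 0.0, 0.0)
# Green carries the largest weight and blue the smallest.
red, green, blue = luma_255(1, 0, 0), luma_255(0, 1, 0), luma_255(0, 0, 1)
print(round(white, 3), black, round(red, 3), round(green, 3), round(blue, 3))
```

Because only this single channel is passed on to the feature extractors, two videos that differ in colour but share the same luma would be scored identically under this sketch.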
1 change: 1 addition & 0 deletions src/torchmetrics/utilities/imports.py
@@ -63,4 +63,5 @@
_TORCH_LINEAR_ASSIGNMENT_AVAILABLE = RequirementCache("torch_linear_assignment")
_AEON_AVAILABLE = RequirementCache("aeon")
_PYTDC_AVAILABLE = RequirementCache("pyTDC")
_TORCH_VMAF_AVAILABLE = RequirementCache("vmaf_torch")
_LATEX_AVAILABLE: bool = shutil.which("latex") is not None
21 changes: 21 additions & 0 deletions src/torchmetrics/video/__init__.py
@@ -0,0 +1,21 @@
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.utilities.imports import _TORCH_VMAF_AVAILABLE

if _TORCH_VMAF_AVAILABLE:
    from torchmetrics.video.vmaf import VideoMultiMethodAssessmentFusion

    __all__ = ["VideoMultiMethodAssessmentFusion"]
else:
    __all__ = []