From 96832e194051a9e2c6f6676e90c400510480b642 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Wed, 5 Mar 2025 14:53:15 +0100
Subject: [PATCH 01/24] add new requirements

---
 requirements/_devel.txt | 1 +
 requirements/video.txt  | 1 +
 2 files changed, 2 insertions(+)
 create mode 100644 requirements/video.txt

diff --git a/requirements/_devel.txt b/requirements/_devel.txt
index 21d2b69e614..61d0d34eeb6 100644
--- a/requirements/_devel.txt
+++ b/requirements/_devel.txt
@@ -10,6 +10,7 @@
 -r image.txt
 -r text.txt
 -r multimodal.txt
+-r video.txt
 -r visual.txt
 
 # add extra testing
diff --git a/requirements/video.txt b/requirements/video.txt
new file mode 100644
index 00000000000..bec31c57bfa
--- /dev/null
+++ b/requirements/video.txt
@@ -0,0 +1 @@
+vmaf_torch @ git+https://github.com/alvitrioliks/VMAF-torch

From e6f8fe294959087ee874c47c48c687fa8995bddd Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Wed, 5 Mar 2025 14:54:22 +0100
Subject: [PATCH 02/24] initial structure

---
 src/torchmetrics/functional/video/__init__.py | 13 ++++++
 src/torchmetrics/functional/video/vmaf.py     | 13 ++++++
 src/torchmetrics/video/__init__.py            | 13 ++++++
 src/torchmetrics/video/vmaf.py                | 40 +++++++++++++++++++
 4 files changed, 79 insertions(+)
 create mode 100644 src/torchmetrics/functional/video/__init__.py
 create mode 100644 src/torchmetrics/functional/video/vmaf.py
 create mode 100644 src/torchmetrics/video/__init__.py
 create mode 100644 src/torchmetrics/video/vmaf.py

diff --git a/src/torchmetrics/functional/video/__init__.py b/src/torchmetrics/functional/video/__init__.py
new file mode 100644
index 00000000000..7f2988bb312
--- /dev/null
+++ b/src/torchmetrics/functional/video/__init__.py
@@ -0,0 +1,13 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
\ No newline at end of file
diff --git a/src/torchmetrics/functional/video/vmaf.py b/src/torchmetrics/functional/video/vmaf.py
new file mode 100644
index 00000000000..7f2988bb312
--- /dev/null
+++ b/src/torchmetrics/functional/video/vmaf.py
@@ -0,0 +1,13 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
\ No newline at end of file
diff --git a/src/torchmetrics/video/__init__.py b/src/torchmetrics/video/__init__.py
new file mode 100644
index 00000000000..7f2988bb312
--- /dev/null
+++ b/src/torchmetrics/video/__init__.py
@@ -0,0 +1,13 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
\ No newline at end of file
diff --git a/src/torchmetrics/video/vmaf.py b/src/torchmetrics/video/vmaf.py
new file mode 100644
index 00000000000..f87003251cc
--- /dev/null
+++ b/src/torchmetrics/video/vmaf.py
@@ -0,0 +1,40 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from torch import Tensor
+from torchmetrics.metric import Metric
+from vmaf_torch.vmaf import VMAF as VMAF_torch
+
+class VMAF(Metric):
+    """
+
+
+
+    .. note::
+        This implementation requires you to have vmaf-torch installed: https://github.com/alvitrioliks/VMAF-torch.
+        Install either by cloning the repository and running `pip install .` or with `pip install torchmetrics[video]`.
+
+    """
+    def __init__(self):
+        super().__init__()
+        self.backend = VMAF_torch()
+        self.backend.compile()
+
+    def update(self, preds: Tensor, target: Tensor) -> None:
+        result = self.backend(ref=target, dist=preds)
+        self.backend.compute_adm_features(ref=target, dist=preds)
+        self.backend.compute_vif_features(ref=target, dist=preds)
+        self.backend.compute_motion(ref=target)
+
+    def compute(self) -> Tensor:
+        pass
\ No newline at end of file

From fe9cd7adba36b9a45822f8825b795d42ac8b14b6 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Wed, 5 Mar 2025 14:56:47 +0100
Subject: [PATCH 03/24] docs

---
 docs/source/index.rst      |  8 ++++++++
 docs/source/video/vmaf.rst | 22 ++++++++++++++++++++++
 2 files changed, 30 insertions(+)
 create mode 100644 docs/source/video/vmaf.rst

diff --git a/docs/source/index.rst b/docs/source/index.rst
index 46670de00e4..7a8b1f5ff7d 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -247,6 +247,14 @@ Or directly from conda
 
       text/*
 
+.. toctree::
+   :maxdepth: 2
+   :name: video
+   :caption: Video
+   :glob:
+
+   video/*
+
 .. toctree::
    :maxdepth: 2
    :name: wrappers
diff --git a/docs/source/video/vmaf.rst b/docs/source/video/vmaf.rst
new file mode 100644
index 00000000000..647405518ea
--- /dev/null
+++ b/docs/source/video/vmaf.rst
@@ -0,0 +1,22 @@
+.. customcarditem::
+   :header: Video Multi-Method Assessment Fusion (VMAF)
+   :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
+   :tags: Video
+
+.. include:: ../links.rst
+
+###########################################
+Video Multi-Method Assessment Fusion (VMAF)
+###########################################
+
+Module Interface
+________________
+
+.. autoclass:: torchmetrics.video.VideoMultiMethodAssessmentFusion
+    :noindex:
+
+Functional Interface
+____________________
+
+.. autofunction:: torchmetrics.functional.video.video_multi_method_assessment_fusion
+    :noindex:

From 16b0639938aec88f5dcf08147ff1e2a66348efc1 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Wed, 5 Mar 2025 14:57:40 +0100
Subject: [PATCH 04/24] test structure

---
 tests/unittests/video/__init__.py  | 13 +++++++++++++
 tests/unittests/video/test_vmaf.py | 13 +++++++++++++
 2 files changed, 26 insertions(+)
 create mode 100644 tests/unittests/video/__init__.py
 create mode 100644 tests/unittests/video/test_vmaf.py

diff --git a/tests/unittests/video/__init__.py b/tests/unittests/video/__init__.py
new file mode 100644
index 00000000000..94f1dec4a9f
--- /dev/null
+++ b/tests/unittests/video/__init__.py
@@ -0,0 +1,13 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/unittests/video/test_vmaf.py b/tests/unittests/video/test_vmaf.py
new file mode 100644
index 00000000000..94f1dec4a9f
--- /dev/null
+++ b/tests/unittests/video/test_vmaf.py
@@ -0,0 +1,13 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

From 06138f6f1a41929ae94bd918f71e9f28dd48af86 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Wed, 5 Mar 2025 15:15:26 +0100
Subject: [PATCH 05/24] add starting point of implementation

---
 requirements/video.txt                    |  1 +
 src/torchmetrics/functional/video/vmaf.py | 74 ++++++++++++++++++++++-
 2 files changed, 74 insertions(+), 1 deletion(-)

diff --git a/requirements/video.txt b/requirements/video.txt
index bec31c57bfa..e2404237106 100644
--- a/requirements/video.txt
+++ b/requirements/video.txt
@@ -1 +1,2 @@
 vmaf_torch @ git+https://github.com/alvitrioliks/VMAF-torch
+einops
diff --git a/src/torchmetrics/functional/video/vmaf.py b/src/torchmetrics/functional/video/vmaf.py
index 7f2988bb312..633c5ad3c48 100644
--- a/src/torchmetrics/functional/video/vmaf.py
+++ b/src/torchmetrics/functional/video/vmaf.py
@@ -10,4 +10,76 @@
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
-# limitations under the License.
\ No newline at end of file
+# limitations under the License.
+import torch
+import vmaf_torch
+from einops import rearrange
+from torch import Tensor
+
+
+def calculate_luma(video: Tensor) -> Tensor:
+    """Calculate the luma component of a video tensor."""
+    r = video[:, :, 0, :, :]
+    g = video[:, :, 1, :, :]
+    b = video[:, :, 2, :, :]
+    return (0.299 * r + 0.587 * g + 0.114 * b).unsqueeze(1) * 255  # [0, 1] -> [0, 255]
+
+
+def video_multi_method_assessment_fusion(
+    preds: Tensor,
+    target: Tensor,
+    elementary_features: bool = False,
+) -> Tensor:
+    """Calculates Video Multi-Method Assessment Fusion (VMAF) metric.
+
+    VMAF combined multiple quality assessment features such as detail loss, motion, and contrast using a machine
+    learning model to predict human perception of video quality more accurately than traditional metrics like PSNR
+    or SSIM.
+
+    .. note::
+        This implementation requires you to have vmaf-torch installed: https://github.com/alvitrioliks/VMAF-torch.
+        Install either by cloning the repository and running `pip install .` or with `pip install torchmetrics[video]`.
+
+    Args:
+        preds: Video tensor of shape (batch, channels, frames, height, width).
+        target: Video tensor of shape (batch, channels, frames, height, width).
+        elementary_features: If True, returns the elementary features used by VMAF.
+
+    Returns:
+        If `elementary_features` is False, returns a tensor with the VMAF score for each video in the batch.
+        If `elementary_features` is True, returns a tensor with the VMAF score and the elementary features used by VMAF.
+
+    Example:
+        >>> import torch
+        >>> from torchmetrics.functional.video import video_multi_method_assessment_fusion
+        >>> preds = torch.rand(2, 3, 10, 32, 32)
+        >>> target = torch.rand(2, 3, 10, 32, 32)
+        >>> vmaf = video_multi_method_assessment_fusion(preds, target)
+        torch.tensor([0.0, 0.0])
+
+    """
+    orig_dtype = preds.dtype
+    device = preds.device
+
+    preds = (preds.clamp(-1, 1).to(torch.float32) + 1) / 2  # [-1, 1] -> [0, 1]
+    target = (target.clamp(-1, 1).to(torch.float32) + 1) / 2  # [-1, 1] -> [0, 1]
+
+    preds_luma = calculate_luma(preds)
+    target_luma = calculate_luma(target)
+
+    vmaf = vmaf_torch.VMAF().to(device)
+
+    score = vmaf(
+        rearrange(target_luma, "b c t h w -> (b t) c h w"), rearrange(preds_luma, "b c t h w -> (b t) c h w")
+    ).to(orig_dtype)
+
+    if elementary_features:
+        adm = vmaf.compute_adm_features(
+            rearrange(target_luma, "b c t h w -> (b t) c h w"), rearrange(preds_luma, "b c t h w -> (b t) c h w")
+        )
+        vif = vmaf.compute_vif_features(
+            rearrange(target_luma, "b c t h w -> (b t) c h w"), rearrange(preds_luma, "b c t h w -> (b t) c h w")
+        )
+        motion = vmaf.compute_motion(rearrange(target_luma, "b c t h w -> (b t) c h w"))
+        score = torch.cat([score, adm, vif, motion], dim=-1)
+    return score

From cf7a8ba7d689e302bbc898be24203dbfaa19b423 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Wed, 5 Mar 2025 15:17:26 +0100
Subject: [PATCH 06/24] changelog

---
 CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6ddc8f833bb..44c5bb76cce 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,6 +12,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Added
 
+- Added `VMAF` metric to new video domain ([#2991](https://github.com/Lightning-AI/torchmetrics/pull/2991))
+
+
 - Added `ClusterAccuracy` metric to cluster package ([#2777](https://github.com/Lightning-AI/torchmetrics/pull/2777))
 

From 6b4397f1e4331ec4e730094e4d66202d093e88c Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 5 Mar 2025 14:20:38 +0000
Subject: [PATCH 07/24] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/torchmetrics/functional/video/__init__.py |  2 +-
 src/torchmetrics/video/__init__.py            |  2 +-
 src/torchmetrics/video/vmaf.py                | 17 ++++++++---------
 3 files changed, 10 insertions(+), 11 deletions(-)

diff --git a/src/torchmetrics/functional/video/__init__.py b/src/torchmetrics/functional/video/__init__.py
index 7f2988bb312..94f1dec4a9f 100644
--- a/src/torchmetrics/functional/video/__init__.py
+++ b/src/torchmetrics/functional/video/__init__.py
@@ -10,4 +10,4 @@
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
-# limitations under the License.
\ No newline at end of file
+# limitations under the License.
diff --git a/src/torchmetrics/video/__init__.py b/src/torchmetrics/video/__init__.py
index 7f2988bb312..94f1dec4a9f 100644
--- a/src/torchmetrics/video/__init__.py
+++ b/src/torchmetrics/video/__init__.py
@@ -10,4 +10,4 @@
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
-# limitations under the License.
\ No newline at end of file
+# limitations under the License.
diff --git a/src/torchmetrics/video/vmaf.py b/src/torchmetrics/video/vmaf.py
index f87003251cc..fd712bb8c58 100644
--- a/src/torchmetrics/video/vmaf.py
+++ b/src/torchmetrics/video/vmaf.py
@@ -12,19 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from torch import Tensor
-from torchmetrics.metric import Metric
 from vmaf_torch.vmaf import VMAF as VMAF_torch
 
+from torchmetrics.metric import Metric
+
+
 class VMAF(Metric):
-    """
-
-
+    """.. note::
+        This implementation requires you to have vmaf-torch installed: https://github.com/alvitrioliks/VMAF-torch.
+        Install either by cloning the repository and running `pip install .` or with `pip install torchmetrics[video]`.
 
-    .. note::
-        This implementation requires you to have vmaf-torch installed: https://github.com/alvitrioliks/VMAF-torch.
-        Install either by cloning the repository and running `pip install .` or with `pip install torchmetrics[video]`.
-
-    """
+    """
+
     def __init__(self):
         super().__init__()
         self.backend = VMAF_torch()
@@ -37,4 +36,4 @@ def update(self, preds: Tensor, target: Tensor) -> None:
         self.backend.compute_motion(ref=target)
 
     def compute(self) -> Tensor:
-        pass
\ No newline at end of file
+        pass

From 792f67e194d9e8dbc50c67535b8d82901cf9cd3c Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Wed, 5 Mar 2025 15:25:51 +0100
Subject: [PATCH 08/24] improve implementations

---
 src/torchmetrics/functional/video/vmaf.py |  5 ++
 src/torchmetrics/utilities/imports.py     |  1 +
 src/torchmetrics/video/vmaf.py            | 62 ++++++++++++++++++-----
 3 files changed, 54 insertions(+), 14 deletions(-)

diff --git a/src/torchmetrics/functional/video/vmaf.py b/src/torchmetrics/functional/video/vmaf.py
index 633c5ad3c48..a99fb0790e5 100644
--- a/src/torchmetrics/functional/video/vmaf.py
+++ b/src/torchmetrics/functional/video/vmaf.py
@@ -16,6 +16,8 @@
 from einops import rearrange
 from torch import Tensor
 
+from torchmetrics.utilities.imports import _TORCH_VMAF_AVAILABLE
+
 
 def calculate_luma(video: Tensor) -> Tensor:
     """Calculate the luma component of a video tensor."""
@@ -58,6 +60,9 @@ def video_multi_method_assessment_fusion(
         torch.tensor([0.0, 0.0])
 
     """
+    if not _TORCH_VMAF_AVAILABLE:
+        raise RuntimeError("vmaf-torch is not installed. Please install with `pip install torchmetrics[video]`.")
+
     orig_dtype = preds.dtype
     device = preds.device
diff --git a/src/torchmetrics/utilities/imports.py b/src/torchmetrics/utilities/imports.py
index 4e291dabec1..a9cdf2ace0f 100644
--- a/src/torchmetrics/utilities/imports.py
+++ b/src/torchmetrics/utilities/imports.py
@@ -63,4 +63,5 @@
 _TORCH_LINEAR_ASSIGNMENT_AVAILABLE = RequirementCache("torch_linear_assignment")
 _AEON_AVAILABLE = RequirementCache("aeon")
 _PYTDC_AVAILABLE = RequirementCache("pyTDC")
+_TORCH_VMAF_AVAILABLE = RequirementCache("torch_vmaf")
 _LATEX_AVAILABLE: bool = shutil.which("latex") is not None
diff --git a/src/torchmetrics/video/vmaf.py b/src/torchmetrics/video/vmaf.py
index f87003251cc..94a4b20e380 100644
--- a/src/torchmetrics/video/vmaf.py
+++ b/src/torchmetrics/video/vmaf.py
@@ -11,30 +11,64 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any, List
+
+import torch
+import vmaf_torch
 from torch import Tensor
+
 from torchmetrics.metric import Metric
-from vmaf_torch.vmaf import VMAF as VMAF_torch
+from torchmetrics.utilities.imports import _TORCH_VMAF_AVAILABLE
 
 
-class VMAF(Metric):
-    """
-
-
+class VideoMultiMethodAssessmentFusion(Metric):
+    """Calculates Video Multi-Method Assessment Fusion (VMAF) metric.
+
+    VMAF combined multiple quality assessment features such as detail loss, motion, and contrast using a machine
+    learning model to predict human perception of video quality more accurately than traditional metrics like PSNR
+    or SSIM.
 
     .. note::
         This implementation requires you to have vmaf-torch installed: https://github.com/alvitrioliks/VMAF-torch.
         Install either by cloning the repository and running `pip install .` or with `pip install torchmetrics[video]`.
+
+    Raises:
+        ValueError: If vmaf-torch is not installed.
+
+    """
+
+    vmaf_score: List[Tensor]
+    adm_features: List[Tensor]
+    vif_features: List[Tensor]
+    motion: List[Tensor]
+
+    def __init__(self, elementary_features: bool = False, **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        if not _TORCH_VMAF_AVAILABLE:
+            raise RuntimeError("vmaf-torch is not installed. Please install with `pip install torchmetrics[video]`.")
+
+        self.backend = vmaf_torch.VMAF().to(self.device)
+        self.backend.compile()
+        if not isinstance(elementary_features, bool):
+            raise ValueError(f"Argument `elementary_features` should be a boolean, but got {elementary_features}.")
+        self.elementary_features = elementary_features
+
+        self.add_state("vmaf_score", default=[], dist_reduce_fx=None)
+        if self.elementary_features:
+            self.add_state("adm_features", default=[], dist_reduce_fx=None)
+            self.add_state("vif_features", default=[], dist_reduce_fx=None)
+            self.add_state("motion", default=[], dist_reduce_fx=None)
 
     def update(self, preds: Tensor, target: Tensor) -> None:
-        result = self.backend(ref=target, dist=preds)
-        self.backend.compute_adm_features(ref=target, dist=preds)
-        self.backend.compute_vif_features(ref=target, dist=preds)
-        self.backend.compute_motion(ref=target)
+        """Calculate VMAF score for each video in the batch."""
+        self.vmaf_score.append(self.backend(ref=target, dist=preds))
+        if self.elementary_features:
+            self.adm_features.append(self.backend.compute_adm_features(ref=target, dist=preds))
+            self.vif_features.append(self.backend.compute_vif_features(ref=target, dist=preds))
+            self.motion.append(self.backend.compute_motion(ref=target))
 
     def compute(self) -> Tensor:
-        pass
+        """Return the VMAF score for each video in the batch."""
+        if self.elementary_features:
+            return torch.cat([self.vmaf_score, self.adm_features, self.vif_features, self.motion], dim=1)
+        return self.vmaf_score

From 6b4daca0b72952c3c4a0cc2dd1f04567f2b5df1f Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Mon, 31 Mar 2025 15:57:57 +0200
Subject: [PATCH 09/24] fix init and import

---
 src/torchmetrics/functional/video/__init__.py | 6 ++++++
 src/torchmetrics/utilities/imports.py         | 2 +-
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/src/torchmetrics/functional/video/__init__.py b/src/torchmetrics/functional/video/__init__.py
index 94f1dec4a9f..62727565d0e 100644
--- a/src/torchmetrics/functional/video/__init__.py
+++ b/src/torchmetrics/functional/video/__init__.py
@@ -11,3 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from torchmetrics.utilities.imports import _TORCH_VMAF_AVAILABLE
+
+if _TORCH_VMAF_AVAILABLE:
+    from torchmetrics.functional.video.vmaf import video_multi_method_assessment_fusion
+
+__all__ = ["video_multi_method_assessment_fusion"]
diff --git a/src/torchmetrics/utilities/imports.py b/src/torchmetrics/utilities/imports.py
index a9cdf2ace0f..e2d143e50c4 100644
--- a/src/torchmetrics/utilities/imports.py
+++ b/src/torchmetrics/utilities/imports.py
@@ -63,5 +63,5 @@
 _TORCH_LINEAR_ASSIGNMENT_AVAILABLE = RequirementCache("torch_linear_assignment")
 _AEON_AVAILABLE = RequirementCache("aeon")
 _PYTDC_AVAILABLE = RequirementCache("pyTDC")
-_TORCH_VMAF_AVAILABLE = RequirementCache("torch_vmaf")
+_TORCH_VMAF_AVAILABLE = RequirementCache("vmaf_torch")
 _LATEX_AVAILABLE: bool = shutil.which("latex") is not None

From 461d70dbf8a35bed074c94fcbd8e4559c5d8a62c Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Mon, 31 Mar 2025 16:06:20 +0200
Subject: [PATCH 10/24] improve functional implementation

---
 src/torchmetrics/functional/video/vmaf.py | 66 +++++++++++++++++++----
 1 file changed, 55 insertions(+), 11 deletions(-)

diff --git a/src/torchmetrics/functional/video/vmaf.py b/src/torchmetrics/functional/video/vmaf.py
index a99fb0790e5..9abc6e6abde 100644
--- a/src/torchmetrics/functional/video/vmaf.py
+++ b/src/torchmetrics/functional/video/vmaf.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Tuple, Union
+
 import torch
 import vmaf_torch
 from einops import rearrange
@@ -31,33 +33,75 @@ def video_multi_method_assessment_fusion(
     preds: Tensor,
     target: Tensor,
     elementary_features: bool = False,
-) -> Tensor:
+) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor, Tensor]]:
     """Calculates Video Multi-Method Assessment Fusion (VMAF) metric.
 
-    VMAF combined multiple quality assessment features such as detail loss, motion, and contrast using a machine
-    learning model to predict human perception of video quality more accurately than traditional metrics like PSNR
-    or SSIM.
+    VMAF is a full-reference video quality assessment algorithm that combines multiple quality assessment features
+    such as detail loss, motion, and contrast using a machine learning model to predict human perception of video
+    quality more accurately than traditional metrics like PSNR or SSIM.
+
+    The metric works by:
+    1. Converting input videos to luma component (grayscale)
+    2. Computing multiple elementary features:
+        - Additive Detail Measure (ADM): Evaluates detail preservation at different scales
+        - Visual Information Fidelity (VIF): Measures preservation of visual information across frequency bands
+        - Motion: Quantifies the amount of motion in the video
+    3. Combining these features using a trained SVM model to predict quality
 
     .. note::
         This implementation requires you to have vmaf-torch installed: https://github.com/alvitrioliks/VMAF-torch.
         Install either by cloning the repository and running `pip install .` or with `pip install torchmetrics[video]`.
 
     Args:
-        preds: Video tensor of shape (batch, channels, frames, height, width).
-        target: Video tensor of shape (batch, channels, frames, height, width).
+        preds: Video tensor of shape (batch, channels, frames, height, width). Expected to be in RGB format
+            with values in range [-1, 1] or [0, 1].
+        target: Video tensor of shape (batch, channels, frames, height, width). Expected to be in RGB format
+            with values in range [-1, 1] or [0, 1].
         elementary_features: If True, returns the elementary features used by VMAF.
 
     Returns:
         If `elementary_features` is False, returns a tensor with the VMAF score for each video in the batch.
-        If `elementary_features` is True, returns a tensor with the VMAF score and the elementary features used by VMAF.
+        Higher scores indicate better quality, with typical values ranging from 0 to 100.
+
+        If `elementary_features` is True, returns a tuple of four tensors:
+        - vmaf_score: The main VMAF score tensor
+        - adm_score: The Additive Detail Measure (ADM) score tensor, which measures the preservation of details
+            in the video. Shape is (batch * frames, 4) where the 4 values represent different detail scales.
+            Higher values indicate better detail preservation.
+        - vif_score: The Visual Information Fidelity (VIF) score tensor, which measures the preservation of
+            visual information. Shape is (batch * frames, 4) where the 4 values represent different frequency bands.
+            Higher values indicate better information preservation.
+        - motion_score: The motion score tensor, which measures the amount of motion in the video.
+            Shape is (batch * frames,). Higher values indicate more motion.
 
     Example:
         >>> import torch
         >>> from torchmetrics.functional.video import video_multi_method_assessment_fusion
         >>> preds = torch.rand(2, 3, 10, 32, 32)
         >>> target = torch.rand(2, 3, 10, 32, 32)
-        >>> vmaf = video_multi_method_assessment_fusion(preds, target)
-        torch.tensor([0.0, 0.0])
+        >>> video_multi_method_assessment_fusion(preds, target)
+        tensor([12.6859, 15.1940, 14.6993, 14.9718, 19.1301, 17.1650])
+        >>> vmaf_score, adm_score, vif_score, motion_score = video_multi_method_assessment_fusion(
+        ...     preds, target, elementary_features=True
+        ... )
+        >>> vmaf_score
+        tensor([12.6859, 15.1940, 14.6993, 14.9718, 19.1301, 17.1650])
+        >>> adm_score
+        tensor([[0.6258, 0.4526, 0.4360, 0.5100],
+                [0.6117, 0.4558, 0.4478, 0.5543],
+                [0.6253, 0.4867, 0.4116, 0.4412],
+                [0.6011, 0.4773, 0.4527, 0.5263],
+                [0.5830, 0.5209, 0.4050, 0.6781],
+                [0.6576, 0.5081, 0.4600, 0.6017]])
+        >>> vif_score
+        tensor([[6.8940e-04, 3.5287e-02, 1.2094e-01, 6.7600e-01],
+                [7.8453e-04, 3.1258e-02, 6.3257e-02, 3.4321e-01],
+                [1.3337e-03, 2.8432e-02, 6.3114e-02, 4.6726e-01],
+                [1.8480e-04, 2.3861e-02, 1.5634e-01, 5.5803e-01],
+                [2.7257e-04, 3.4004e-02, 1.6240e-01, 6.9619e-01],
+                [1.2596e-03, 2.1799e-02, 1.0870e-01, 2.2582e-01]])
+        >>> motion_score
+        tensor([0.0000, 8.8821, 9.0885, 8.7898, 7.8289, 8.0279])
 
     """
     if not _TORCH_VMAF_AVAILABLE:
         raise RuntimeError("vmaf-torch is not installed. Please install with `pip install torchmetrics[video]`.")
 
     orig_dtype = preds.dtype
     device = preds.device
 
     preds = (preds.clamp(-1, 1).to(torch.float32) + 1) / 2  # [-1, 1] -> [0, 1]
     target = (target.clamp(-1, 1).to(torch.float32) + 1) / 2  # [-1, 1] -> [0, 1]
 
     preds_luma = calculate_luma(preds)
     target_luma = calculate_luma(target)
 
     vmaf = vmaf_torch.VMAF().to(device)
 
     score = vmaf(
         rearrange(target_luma, "b c t h w -> (b t) c h w"), rearrange(preds_luma, "b c t h w -> (b t) c h w")
     ).to(orig_dtype)
 
     if elementary_features:
         adm = vmaf.compute_adm_features(
             rearrange(target_luma, "b c t h w -> (b t) c h w"), rearrange(preds_luma, "b c t h w -> (b t) c h w")
         )
         vif = vmaf.compute_vif_features(
             rearrange(target_luma, "b c t h w -> (b t) c h w"), rearrange(preds_luma, "b c t h w -> (b t) c h w")
         )
         motion = vmaf.compute_motion(rearrange(target_luma, "b c t h w -> (b t) c h w"))
-        score = torch.cat([score, adm, vif, motion], dim=-1)
-    return score
+        return score.squeeze(), adm, vif, motion.squeeze()
+    return score.squeeze()

From 647bb9cd26b3559d3f34b4e6d6f40efc7aaf504c Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Mon, 31 Mar 2025 16:11:38 +0200
Subject: [PATCH 11/24] improve src

---
 src/torchmetrics/functional/video/__init__.py |   4 +-
 src/torchmetrics/video/__init__.py            |   8 +
 src/torchmetrics/video/vmaf.py                | 155 +++++++++++++++---
 3 files changed, 145 insertions(+), 22 deletions(-)

diff --git a/src/torchmetrics/functional/video/__init__.py b/src/torchmetrics/functional/video/__init__.py
index 62727565d0e..5cd6d89ebc3 100644
--- a/src/torchmetrics/functional/video/__init__.py
+++ b/src/torchmetrics/functional/video/__init__.py
@@ -16,4 +16,6 @@
 if _TORCH_VMAF_AVAILABLE:
     from torchmetrics.functional.video.vmaf import video_multi_method_assessment_fusion
 
-__all__ = ["video_multi_method_assessment_fusion"]
+    __all__ = ["video_multi_method_assessment_fusion"]
+else:
+    __all__ = []
diff --git a/src/torchmetrics/video/__init__.py b/src/torchmetrics/video/__init__.py
index 94f1dec4a9f..a7c58cd62bb 100644
--- a/src/torchmetrics/video/__init__.py
+++ b/src/torchmetrics/video/__init__.py
@@ -11,3 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from torchmetrics.utilities.imports import _TORCH_VMAF_AVAILABLE
+
+if _TORCH_VMAF_AVAILABLE:
+    from torchmetrics.video.vmaf import VideoMultiMethodAssessmentFusion
+
+    __all__ = ["VideoMultiMethodAssessmentFusion"]
+else:
+    __all__ = []
diff --git a/src/torchmetrics/video/vmaf.py b/src/torchmetrics/video/vmaf.py
index 94a4b20e380..9e863b37dd5 100644
--- a/src/torchmetrics/video/vmaf.py
+++ b/src/torchmetrics/video/vmaf.py
@@ -11,32 +11,105 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, List
+from typing import Any, List, Optional, Tuple, Union
 
 import torch
-import vmaf_torch
 from torch import Tensor
 
+from torchmetrics.functional.video.vmaf import video_multi_method_assessment_fusion
 from torchmetrics.metric import Metric
 from torchmetrics.utilities.imports import _TORCH_VMAF_AVAILABLE
+from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
 
 
 class VideoMultiMethodAssessmentFusion(Metric):
     """Calculates Video Multi-Method Assessment Fusion (VMAF) metric.
 
-    VMAF combined multiple quality assessment features such as detail loss, motion, and contrast using a machine
-    learning model to predict human perception of video quality more accurately than traditional metrics like PSNR
-    or SSIM.
+    VMAF is a full-reference video quality assessment algorithm that combines multiple quality assessment features
+    such as detail loss, motion, and contrast using a machine learning model to predict human perception of video
+    quality more accurately than traditional metrics like PSNR or SSIM.
+
+    The metric works by:
+    1. Converting input videos to luma component (grayscale)
+    2. Computing multiple elementary features:
+        - Additive Detail Measure (ADM): Evaluates detail preservation at different scales
+        - Visual Information Fidelity (VIF): Measures preservation of visual information across frequency bands
+        - Motion: Quantifies the amount of motion in the video
+    3. Combining these features using a trained SVM model to predict quality
 
     .. note::
         This implementation requires you to have vmaf-torch installed: https://github.com/alvitrioliks/VMAF-torch.
         Install either by cloning the repository and running `pip install .` or with `pip install torchmetrics[video]`.
 
+    As input to ``forward`` and ``update`` the metric accepts the following input
+
+    - ``preds`` (:class:`~torch.Tensor`): Video tensor of shape ``(batch, channels, frames, height, width)``.
+      Expected to be in RGB format with values in range [-1, 1] or [0, 1].
+    - ``target`` (:class:`~torch.Tensor`): Video tensor of shape ``(batch, channels, frames, height, width)``.
+      Expected to be in RGB format with values in range [-1, 1] or [0, 1].
+
+    As output of `forward` and `compute` the metric returns the following output
+
+    If `elementary_features` is False:
+    - ``vmaf`` (:class:`~torch.Tensor`): A tensor with the VMAF score for each video in the batch.
+      Higher scores indicate better quality, with typical values ranging from 0 to 100.
+
+    If `elementary_features` is True:
+    - ``vmaf_score`` (:class:`~torch.Tensor`): The main VMAF score tensor
+    - ``adm_score`` (:class:`~torch.Tensor`): The Additive Detail Measure (ADM) score tensor, which measures
+      the preservation of details in the video. Shape is (batch * frames, 4) where the 4 values represent
+      different detail scales. Higher values indicate better detail preservation.
+    - ``vif_score`` (:class:`~torch.Tensor`): The Visual Information Fidelity (VIF) score tensor, which measures
+      the preservation of visual information. Shape is (batch * frames, 4) where the 4 values represent different
+      frequency bands. Higher values indicate better information preservation.
+    - ``motion_score`` (:class:`~torch.Tensor`): The motion score tensor, which measures the amount of motion
+      in the video. Shape is (batch * frames,). Higher values indicate more motion.
+
+    Args:
+        elementary_features: If True, returns the elementary features used by VMAF.
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+
+    Raises:
+        RuntimeError: If vmaf-torch is not installed.
+        ValueError: If `elementary_features` is not a boolean.
+
+    Example:
+        >>> from torch import rand
+        >>> from torchmetrics.video import VideoMultiMethodAssessmentFusion
+        >>> preds = rand(2, 3, 10, 32, 32)
+        >>> target = rand(2, 3, 10, 32, 32)
+        >>> vmaf = VideoMultiMethodAssessmentFusion()
+        >>> vmaf(preds, target)
+        tensor([12.6859, 15.1940, 14.6993, 14.9718, 19.1301, 17.1650])
+        >>> vmaf = VideoMultiMethodAssessmentFusion(elementary_features=True)
+        >>> vmaf_score, adm_score, vif_score, motion_score = vmaf(preds, target)
+        >>> vmaf_score
+        tensor([12.6859, 15.1940, 14.6993, 14.9718, 19.1301, 17.1650])
+        >>> adm_score
+        tensor([[0.6258, 0.4526, 0.4360, 0.5100],
+                [0.6117, 0.4558, 0.4478, 0.5543],
+                [0.6253, 0.4867, 0.4116, 0.4412],
+                [0.6011, 0.4773, 0.4527, 0.5263],
+                [0.5830, 0.5209, 0.4050, 0.6781],
+                [0.6576, 0.5081, 0.4600, 0.6017]])
+        >>> vif_score
+        tensor([[6.8940e-04, 3.5287e-02, 1.2094e-01, 6.7600e-01],
+                [7.8453e-04, 3.1258e-02, 6.3257e-02, 3.4321e-01],
+                [1.3337e-03, 2.8432e-02, 6.3114e-02, 4.6726e-01],
+                [1.8480e-04, 2.3861e-02, 1.5634e-01, 5.5803e-01],
+                [2.7257e-04, 3.4004e-02, 1.6240e-01, 6.9619e-01],
+                [1.2596e-03, 2.1799e-02, 1.0870e-01, 2.2582e-01]])
+        >>> motion_score
+        tensor([0.0000, 8.8821, 9.0885, 8.7898, 7.8289, 8.0279])
 
-    Raises:
-        ValueError: If vmaf-torch is not installed.
     """
 
+    is_differentiable: bool = False
+    higher_is_better: bool = True
+    full_state_update: bool = False
+    plot_lower_bound: float = 0.0
+    plot_upper_bound: float = 100.0  # Updated to match VMAF score range
+
     vmaf_score: List[Tensor]
     adm_features: List[Tensor]
     vif_features: List[Tensor]
     motion: List[Tensor]
 
     def __init__(self, elementary_features: bool = False, **kwargs: Any) -> None:
         super().__init__(**kwargs)
         if not _TORCH_VMAF_AVAILABLE:
             raise RuntimeError("vmaf-torch is not installed. Please install with `pip install torchmetrics[video]`.")
 
-        self.backend = vmaf_torch.VMAF().to(self.device)
-        self.backend.compile()
         if not isinstance(elementary_features, bool):
             raise ValueError(f"Argument `elementary_features` should be a boolean, but got {elementary_features}.")
         self.elementary_features = elementary_features
 
-        self.add_state("vmaf_score", default=[], dist_reduce_fx=None)
+        self.add_state("vmaf_score", default=[], dist_reduce_fx="cat")
         if self.elementary_features:
-            self.add_state("adm_features", default=[], dist_reduce_fx=None)
-            self.add_state("vif_features", default=[], dist_reduce_fx=None)
-            self.add_state("motion", default=[], dist_reduce_fx=None)
+            self.add_state("adm_features", default=[], dist_reduce_fx="cat")
+            self.add_state("vif_features", default=[], dist_reduce_fx="cat")
+            self.add_state("motion", default=[], dist_reduce_fx="cat")
 
     def update(self, preds: Tensor, target: Tensor) -> None:
-        """Calculate VMAF score for each video in the batch."""
-        self.vmaf_score.append(self.backend(ref=target, dist=preds))
+        """Update state with predictions and targets."""
+        score = video_multi_method_assessment_fusion(preds, target, self.elementary_features)
         if self.elementary_features:
-            self.adm_features.append(self.backend.compute_adm_features(ref=target, dist=preds))
-            self.vif_features.append(self.backend.compute_vif_features(ref=target, dist=preds))
-            self.motion.append(self.backend.compute_motion(ref=target))
+            self.vmaf_score.append(score[0])
+            self.adm_features.append(score[1])
+            self.vif_features.append(score[2])
+            self.motion.append(score[3])
+        else:
+            self.vmaf_score.append(score)
 
-    def compute(self) -> Tensor:
-        """Return the VMAF score for each video in the batch."""
+    def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor, Tensor]]:
+        """Compute final VMAF score."""
         if self.elementary_features:
-            return torch.cat([self.vmaf_score, self.adm_features, self.vif_features, self.motion], dim=1)
-        return self.vmaf_score
+            return (
+                torch.cat(self.vmaf_score, dim=0),
+                torch.cat(self.adm_features, dim=0),
+                torch.cat(self.vif_features, dim=0),
+                torch.cat(self.motion, dim=0),
+            )
+        return torch.cat(self.vmaf_score, dim=0)
+
+    def plot(
+        self, val: Optional[Union[Tensor, Tuple[Tensor, Tensor, Tensor, Tensor]]] = None, ax: Optional[_AX_TYPE] = None
+    ) -> _PLOT_OUT_TYPE:
+        """Plot a single or multiple values from the metric.
+
+        Args:
+            val: Either a single result from calling `metric.forward()` or `metric.compute()`, or the results
+                from multiple calls of `metric.forward()` or `metric.compute()`. If no value is provided, will
+                automatically call `metric.compute()` and plot that result.
+            ax: A matplotlib axis object. If provided will add plot to that axis
+
+        Returns:
+            Figure and Axes object
+
+        Raises:
+            ModuleNotFoundError:
+                If `matplotlib` is not installed
+
+        .. plot::
+            :scale: 75
+            :caption: Example of plotting VMAF scores and elementary features.
+
+            >>> # Example plotting a single value
+            >>> from torch import rand
+            >>> from torchmetrics.video import VideoMultiMethodAssessmentFusion
+            >>> metric = VideoMultiMethodAssessmentFusion(elementary_features=True)
+            >>> preds = rand(2, 3, 10, 32, 32)
+            >>> target = rand(2, 3, 10, 32, 32)
+            >>> metric.update(preds, target)
+            >>> fig_, ax_ = metric.plot()
+
+        """
+        return super().plot(val=val, ax=ax)

From ad2a46af73fd5a939ba0c0ad2467fc854eb848c1 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Mon, 31 Mar 2025 16:16:30 +0200
Subject: [PATCH 12/24] add testing

---
 tests/unittests/video/test_vmaf.py | 161 +++++++++++++++++++++++++++++
 1 file changed, 161 insertions(+)

diff --git a/tests/unittests/video/test_vmaf.py b/tests/unittests/video/test_vmaf.py
index 94f1dec4a9f..5efb3e35a56 100644
--- a/tests/unittests/video/test_vmaf.py
+++ b/tests/unittests/video/test_vmaf.py
@@ -11,3 +11,164 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+import pytest
+import torch
+import vmaf_torch
+from einops import rearrange
+
+from torchmetrics.functional.video.vmaf import video_multi_method_assessment_fusion
+from torchmetrics.utilities.imports import _TORCH_VMAF_AVAILABLE
+from torchmetrics.video import VideoMultiMethodAssessmentFusion
+from unittests import BATCH_SIZE, NUM_BATCHES, _Input
+from unittests._helpers import seed_all
+from unittests._helpers.testers import MetricTester
+
+seed_all(42)
+
+
+def _reference_vmaf(preds, target, elementary_features=False):
+    """Reference implementation of VMAF metric."""
+    device = preds.device
+    orig_dtype = preds.dtype
+
+    # Convert to float32 for processing
+    preds = (preds.clamp(-1, 1).to(torch.float32) + 1) / 2  # [-1, 1] -> [0, 1]
+    target = (target.clamp(-1, 1).to(torch.float32) + 1) / 2  # [-1, 1] -> [0, 1]
+
+    # Calculate luma component
+    def calculate_luma(video):
+        r = video[:, :, 0, :, :]
+        g = video[:, :, 1, :, :]
+        b = video[:, :, 2, :, :]
+        return (0.299 * r + 0.587 * g + 0.114 * b).unsqueeze(1) * 255  # [0, 1] -> [0, 255]
+
+    preds_luma = calculate_luma(preds)
+    target_luma = calculate_luma(target)
+
+    vmaf = vmaf_torch.VMAF().to(device)
+
+    score = vmaf(rearrange(target_luma, "b c t h w -> (b t) c h w"), rearrange(preds_luma, "b c t h w -> (b t) c h w"))
+
+    if elementary_features:
+        adm = vmaf.compute_adm_features(
+            rearrange(target_luma, "b c t h w -> (b t) c h w"), rearrange(preds_luma, "b c t h w -> (b t) c h w")
+        )
+        vif = vmaf.compute_vif_features(
+            rearrange(target_luma, "b c t h w -> (b t) c h w"), rearrange(preds_luma, "b c t h w -> (b t) c h w")
+        )
+        motion = vmaf.compute_motion(rearrange(target_luma, "b c t h w -> (b t) c h w"))
+        return score.squeeze().to(orig_dtype), adm.to(orig_dtype), vif.to(orig_dtype), motion.squeeze().to(orig_dtype)
+    return score.squeeze().to(orig_dtype)
+
+
+# Define inputs
+_inputs = []
+for size in [32, 64]:
+    preds = torch.rand(NUM_BATCHES, BATCH_SIZE, 3, 10, size, size)
+    target = torch.rand(NUM_BATCHES, BATCH_SIZE, 3, 10, size, size)
+    _inputs.append(_Input(preds=preds, target=target))
+
+
+@pytest.mark.skipif(not _TORCH_VMAF_AVAILABLE, reason="test requires vmaf-torch")
+@pytest.mark.parametrize("preds, target", [(i.preds, i.target) for i in _inputs])
+class TestVMAF(MetricTester):
+    """Test class for `VideoMultiMethodAssessmentFusion` metric."""
+
+    atol = 1e-6
+
+    @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
+    def test_vmaf(self, preds, target, ddp):
+        """Test functional implementation of metric."""
+        self.run_functional_metric_test(
+            preds,
+            target,
+            metric_functional=video_multi_method_assessment_fusion,
+            reference_metric=_reference_vmaf,
+        )
+
+    @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
+    def test_vmaf_class(self, preds, target, ddp):
+        """Test class implementation of metric."""
+        self.run_class_metric_test(
+            ddp,
+            preds,
+            target,
+            VideoMultiMethodAssessmentFusion,
+            _reference_vmaf,
+        )
+
+    def test_vmaf_elementary_features(self, preds, target):
+        """Test that elementary features are returned when requested."""
+        # Test functional implementation
+        score = video_multi_method_assessment_fusion(preds, target, elementary_features=True)
+        assert isinstance(score, tuple)
+        assert len(score) == 4  # VMAF score + ADM + VIF + motion
+        assert score[0].shape == (NUM_BATCHES * BATCH_SIZE * 10,)  # VMAF score shape
+        assert score[1].shape == (NUM_BATCHES * BATCH_SIZE * 10, 4)  # ADM shape
+        assert score[2].shape == (NUM_BATCHES * BATCH_SIZE * 10, 4)  # VIF shape
+        assert score[3].shape == (NUM_BATCHES * BATCH_SIZE * 10,)  # Motion shape
+
+        # Test class implementation
+        metric = VideoMultiMethodAssessmentFusion(elementary_features=True)
+        metric.update(preds, target)
+        score = metric.compute()
+        assert isinstance(score, tuple)
+        assert len(score) == 4
+        assert score[0].shape == (NUM_BATCHES * BATCH_SIZE * 10,)
+        assert score[1].shape == (NUM_BATCHES * BATCH_SIZE * 10, 4)
+        assert score[2].shape == (NUM_BATCHES * BATCH_SIZE * 10, 4)
+        assert score[3].shape == (NUM_BATCHES * BATCH_SIZE * 10,)
+
+    def test_vmaf_half_cpu(self, preds, target):
+        """Test for half precision on CPU."""
+        # Convert inputs to half precision
+        preds = preds.to(torch.float16)
+        target = target.to(torch.float16)
+
+        self.run_precision_test_cpu(
+            preds,
+            target,
+            video_multi_method_assessment_fusion,
+            _reference_vmaf,
+        )
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
+    def test_vmaf_half_gpu(self, preds, target):
+        """Test for half precision on GPU."""
+        # Convert inputs to half precision
+        preds = preds.to(torch.float16)
+        target = target.to(torch.float16)
+
+        self.run_precision_test_gpu(
+            preds,
+            target,
+            video_multi_method_assessment_fusion,
+            _reference_vmaf,
+        )
+
+    def test_vmaf_plot(self, preds, target):
+        """Test the plot method of the metric."""
+        # Test basic VMAF plotting
+        metric = VideoMultiMethodAssessmentFusion()
+        metric.update(preds, target)
+        fig, ax = metric.plot()
+        assert fig is not None
+        assert ax is not None
+
+        # Test plotting with elementary features
+        metric = VideoMultiMethodAssessmentFusion(elementary_features=True)
+        metric.update(preds, target)
+        fig, ax = metric.plot()
+        assert fig is not None
+        assert ax is not None
+
+
+@pytest.mark.skipif(_TORCH_VMAF_AVAILABLE, reason="test requires vmaf-torch")
+def test_vmaf_raises_error():
+    """Test that appropriate error is raised when vmaf-torch is not installed."""
+    with pytest.raises(RuntimeError, match="vmaf-torch is not installed"):
+        video_multi_method_assessment_fusion(torch.rand(1, 3, 10, 32, 32), torch.rand(1, 3, 10, 32, 32))
+
+    with pytest.raises(RuntimeError, match="vmaf-torch is not installed"):
+        VideoMultiMethodAssessmentFusion()

From 43322c2e486084087649864ed84a492c52b3930a Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Mon, 31 Mar 2025 16:24:34 +0200
Subject: [PATCH 13/24] improve tests

---
 tests/unittests/video/test_vmaf.py | 71 ++++++++++--------------------
 1 file changed, 23 insertions(+), 48 deletions(-)

diff --git a/tests/unittests/video/test_vmaf.py b/tests/unittests/video/test_vmaf.py
index 5efb3e35a56..d58afd24f42 100644
--- a/tests/unittests/video/test_vmaf.py
+++ b/tests/unittests/video/test_vmaf.py
@@ -65,8 +65,8 @@ def calculate_luma(video):
 # Define inputs
 _inputs = []
 for size in [32, 64]:
-    preds = torch.rand(NUM_BATCHES, BATCH_SIZE, 3, 10, size, size)
-    target = torch.rand(NUM_BATCHES, BATCH_SIZE, 3, 10, size, size)
+    preds = torch.rand(2, 4, 3, 10, size, size)
+    target = torch.rand(2, 4, 3, 10, size, size)
     _inputs.append(_Input(preds=preds, target=target))
 
 
@@ -75,29 +75,28 @@ def calculate_luma(video):
 class TestVMAF(MetricTester):
     """Test class for `VideoMultiMethodAssessmentFusion` metric."""
 
-    atol = 1e-6
+    atol = 1e-3
 
     @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
     def test_vmaf(self, preds, target, ddp):
+        """Test class implementation of metric."""
+        self.run_class_metric_test(
+            ddp=ddp,
+            preds=preds,
+            target=target,
+            metric_class=VideoMultiMethodAssessmentFusion,
+            reference_metric=_reference_vmaf,
+        )
+
+    def test_vmaf_functional(self, preds, target):
         """Test functional implementation of metric."""
         self.run_functional_metric_test(
-            preds,
-            target,
+            preds=preds,
+            target=target,
             metric_functional=video_multi_method_assessment_fusion,
             reference_metric=_reference_vmaf,
         )
 
-    @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
-    def test_vmaf_class(self, preds, target, ddp):
-        """Test class implementation of metric."""
-        self.run_class_metric_test(
-            ddp,
-            preds,
-            target,
-            VideoMultiMethodAssessmentFusion,
-            _reference_vmaf,
-        )
-
     def test_vmaf_elementary_features(self, preds, target):
         """Test that elementary features are returned when requested."""
         # Test functional implementation
@@ -122,47 +121,23 @@ def test_vmaf_elementary_features(self, preds, target):
 
     def test_vmaf_half_cpu(self, preds, target):
         """Test for half precision on CPU."""
-        # Convert inputs to half precision
-        preds = preds.to(torch.float16)
-        target = target.to(torch.float16)
-
         self.run_precision_test_cpu(
-            preds,
-            target,
-            video_multi_method_assessment_fusion,
-            _reference_vmaf,
+            preds=preds,
+            target=target,
+            metric_module=VideoMultiMethodAssessmentFusion,
+            metric_functional=video_multi_method_assessment_fusion,
         )
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
     def test_vmaf_half_gpu(self, preds, target):
         """Test for half precision on GPU."""
-        # Convert inputs to half precision
-        preds = preds.to(torch.float16)
-        target = target.to(torch.float16)
-
         self.run_precision_test_gpu(
-            preds,
-            target,
-            video_multi_method_assessment_fusion,
-            _reference_vmaf,
+            preds=preds,
+            target=target,
+            metric_module=VideoMultiMethodAssessmentFusion,
+            metric_functional=video_multi_method_assessment_fusion,
         )
From 1072c611031fcf5c8be0d38b2a40e9adb6c1326e Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Mon, 31 Mar 2025 16:30:17 +0200
Subject: [PATCH 14/24] tests

---
 tests/unittests/video/test_vmaf.py | 31 ++++++++++--------------------
 1 file changed, 10 insertions(+), 21 deletions(-)

diff --git a/tests/unittests/video/test_vmaf.py b/tests/unittests/video/test_vmaf.py
index d58afd24f42..1fa1ecfca50 100644
--- a/tests/unittests/video/test_vmaf.py
+++ b/tests/unittests/video/test_vmaf.py
@@ -20,7 +20,7 @@
 from torchmetrics.functional.video.vmaf import video_multi_method_assessment_fusion
 from torchmetrics.utilities.imports import _TORCH_VMAF_AVAILABLE
 from torchmetrics.video import VideoMultiMethodAssessmentFusion
-from unittests import BATCH_SIZE, NUM_BATCHES, _Input
+from unittests import _Input
 from unittests._helpers import seed_all
 from unittests._helpers.testers import MetricTester
 
@@ -63,10 +63,11 @@ def calculate_luma(video):
 
 # Define inputs
+NUM_BATCHES, BATCH_SIZE = 2, 4
 _inputs = []
 for size in [32, 64]:
-    preds = torch.rand(2, 4, 3, 10, size, size)
-    target = torch.rand(2, 4, 3, 10, size, size)
+    preds = torch.rand(NUM_BATCHES, BATCH_SIZE, 3, 10, size, size)
+    target = torch.rand(NUM_BATCHES, BATCH_SIZE, 3, 10, size, size)
     _inputs.append(_Input(preds=preds, target=target))
 
@@ -100,24 +101,14 @@ def test_vmaf_functional(self, preds, target):
     def test_vmaf_elementary_features(self, preds, target):
         """Test that elementary features are returned when requested."""
         # Test functional implementation
-        score = video_multi_method_assessment_fusion(preds, target, elementary_features=True)
+        score = video_multi_method_assessment_fusion(preds[0], target[0], elementary_features=True)
         assert isinstance(score, tuple)
         assert len(score) == 4  # VMAF score + ADM + VIF + motion
-        assert score[0].shape == (NUM_BATCHES * BATCH_SIZE * 10,)  # VMAF score shape
-        assert score[1].shape == (NUM_BATCHES * BATCH_SIZE * 10, 4)  # ADM shape
-        assert score[2].shape == (NUM_BATCHES * BATCH_SIZE * 10, 4)  # VIF shape
-        assert score[3].shape == (NUM_BATCHES * BATCH_SIZE * 10,)  # Motion shape
-
-        # Test class implementation
-        metric = VideoMultiMethodAssessmentFusion(elementary_features=True)
-        metric.update(preds, target)
-        score = metric.compute()
-        assert isinstance(score, tuple)
-        assert len(score) == 4
-        assert score[0].shape == (NUM_BATCHES * BATCH_SIZE * 10,)
-        assert score[1].shape == (NUM_BATCHES * BATCH_SIZE * 10, 4)
-        assert score[2].shape == (NUM_BATCHES * BATCH_SIZE * 10, 4)
-        assert score[3].shape == (NUM_BATCHES * BATCH_SIZE * 10,)
+        assert score[0].shape == (BATCH_SIZE,)  # VMAF score shape
+        assert score[1].shape == (BATCH_SIZE, 4)  # ADM shape
+        assert score[2].shape == (BATCH_SIZE, 4)  # VIF shape
+        assert score[3].shape == (BATCH_SIZE,)  # Motion shape
 
     def test_vmaf_half_cpu(self, preds, target):
         """Test for half precision on CPU."""
@@ -145,5 +136,3 @@ def test_vmaf_raises_error():
     with pytest.raises(RuntimeError, match="vmaf-torch is not installed"):
         video_multi_method_assessment_fusion(torch.rand(1, 3, 10, 32, 32), torch.rand(1, 3, 10, 32, 32))
-
-    with pytest.raises(RuntimeError, match="vmaf-torch is not installed"):
-        VideoMultiMethodAssessmentFusion()

From dcfa02704abc4c2306920662b309e4ea706660e0 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 1 Apr 2025 15:47:50 +0000
Subject: [PATCH 15/24] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/unittests/video/test_vmaf.py | 1 -
 1 file changed, 1 deletion(-)

a/tests/unittests/video/test_vmaf.py b/tests/unittests/video/test_vmaf.py index 1fa1ecfca50..03cb4117cf2 100644 --- a/tests/unittests/video/test_vmaf.py +++ b/tests/unittests/video/test_vmaf.py @@ -135,4 +135,3 @@ def test_vmaf_raises_error(): """Test that appropriate error is raised when vmaf-torch is not installed.""" with pytest.raises(RuntimeError, match="vmaf-torch is not installed"): video_multi_method_assessment_fusion(torch.rand(1, 3, 10, 32, 32), torch.rand(1, 3, 10, 32, 32)) - From 6c11a89f8c6059edd2755259a11c603cfab7815d Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 9 Apr 2025 08:04:11 +0200 Subject: [PATCH 16/24] smaller changes --- src/torchmetrics/functional/video/vmaf.py | 47 ++++++++++------------- src/torchmetrics/video/vmaf.py | 16 ++++---- 2 files changed, 28 insertions(+), 35 deletions(-) diff --git a/src/torchmetrics/functional/video/vmaf.py b/src/torchmetrics/functional/video/vmaf.py index 9abc6e6abde..272e8b0be0e 100644 --- a/src/torchmetrics/functional/video/vmaf.py +++ b/src/torchmetrics/functional/video/vmaf.py @@ -54,9 +54,9 @@ def video_multi_method_assessment_fusion( Args: preds: Video tensor of shape (batch, channels, frames, height, width). Expected to be in RGB format - with values in range [-1, 1] or [0, 1]. + with values in range [0, 1]. target: Video tensor of shape (batch, channels, frames, height, width). Expected to be in RGB format - with values in range [-1, 1] or [0, 1]. + with values in range [0, 1]. elementary_features: If True, returns the elementary features used by VMAF. Returns: @@ -77,48 +77,43 @@ def video_multi_method_assessment_fusion( Example: >>> import torch >>> from torchmetrics.functional.video import video_multi_method_assessment_fusion - >>> preds = torch.rand(2, 3, 10, 32, 32) - >>> target = torch.rand(2, 3, 10, 32, 32) + >>> preds = torch.rand(2, 3, 10, 32, 32) # 2 videos, 3 channels, 10 frames, 32x32 resolution + >>> target = torch.rand(2, 3, 10, 32, 32) # 2 videos, 3 channels, 10 frames, 32x32 resolution >>> video_multi_method_assessment_fusion(preds, target) - tensor([12.6859, 15.1940, 14.6993, 14.9718, 19.1301, 17.1650]) + tensor([ 3.9553, 15.2808, 15.0131, 13.7132, 14.0283, 16.9560]) >>> vmaf_score, adm_score, vif_score, motion_score = video_multi_method_assessment_fusion( ... preds, target, elementary_features=True ... 
) >>> vmaf_score - tensor([12.6859, 15.1940, 14.6993, 14.9718, 19.1301, 17.1650]) + tensor([ 3.9553, 15.2808, 15.0131, 13.7132, 14.0283, 16.9560]) >>> adm_score - tensor([[0.6258, 0.4526, 0.4360, 0.5100], - [0.6117, 0.4558, 0.4478, 0.5543], - [0.6253, 0.4867, 0.4116, 0.4412], - [0.6011, 0.4773, 0.4527, 0.5263], - [0.5830, 0.5209, 0.4050, 0.6781], - [0.6576, 0.5081, 0.4600, 0.6017]]) + tensor([[0.5128, 0.3583, 0.3275, 0.3798], + [0.5034, 0.3581, 0.3421, 0.4753], + [0.5161, 0.3987, 0.3176, 0.2830], + [0.4823, 0.3802, 0.3569, 0.4263], + [0.4627, 0.4267, 0.2862, 0.5625], + [0.5576, 0.4112, 0.3703, 0.5333]]) >>> vif_score - tensor([[6.8940e-04, 3.5287e-02, 1.2094e-01, 6.7600e-01], - [7.8453e-04, 3.1258e-02, 6.3257e-02, 3.4321e-01], - [1.3337e-03, 2.8432e-02, 6.3114e-02, 4.6726e-01], - [1.8480e-04, 2.3861e-02, 1.5634e-01, 5.5803e-01], - [2.7257e-04, 3.4004e-02, 1.6240e-01, 6.9619e-01], - [1.2596e-03, 2.1799e-02, 1.0870e-01, 2.2582e-01]]) + tensor([[5.5589e-04, 2.3668e-02, 5.9746e-02, 1.8287e-01], + [6.3305e-04, 2.1592e-02, 3.8126e-02, 1.1630e-01], + [1.0766e-03, 1.9478e-02, 3.5908e-02, 5.0494e-02], + [1.4880e-04, 1.6239e-02, 2.6883e-02, 1.5944e-01], + [2.1966e-04, 2.2355e-02, 6.6175e-02, 5.8169e-02], + [1.0138e-03, 1.5265e-02, 5.5632e-02, 1.2230e-01]]) >>> motion_score - tensor([0.0000, 8.8821, 9.0885, 8.7898, 7.8289, 8.0279]) + tensor([ 0.0000, 17.7642, 18.1769, 17.5795, 15.6578, 16.0557]) """ if not _TORCH_VMAF_AVAILABLE: raise RuntimeError("vmaf-torch is not installed. Please install with `pip install torchmetrics[video]`.") - orig_dtype = preds.dtype - device = preds.device - - preds = (preds.clamp(-1, 1).to(torch.float32) + 1) / 2 # [-1, 1] -> [0, 1] - target = (target.clamp(-1, 1).to(torch.float32) + 1) / 2 # [-1, 1] -> [0, 1] - + orig_dtype, device = preds.dtype, preds.device preds_luma = calculate_luma(preds) target_luma = calculate_luma(target) vmaf = vmaf_torch.VMAF().to(device) - score = vmaf( + score: Tensor = vmaf( rearrange(target_luma, "b c t h w -> (b t) c h w"), rearrange(preds_luma, "b c t h w -> (b t) c h w") ).to(orig_dtype) diff --git a/src/torchmetrics/video/vmaf.py b/src/torchmetrics/video/vmaf.py index 9e863b37dd5..d4cac307f8c 100644 --- a/src/torchmetrics/video/vmaf.py +++ b/src/torchmetrics/video/vmaf.py @@ -44,17 +44,15 @@ class VideoMultiMethodAssessmentFusion(Metric): As input to ``forward`` and ``update`` the metric accepts the following input - ``preds`` (:class:`~torch.Tensor`): Video tensor of shape ``(batch, channels, frames, height, width)``. - Expected to be in RGB format with values in range [-1, 1] or [0, 1]. + Expected to be in RGB format with values in range [0, 1]. - ``target`` (:class:`~torch.Tensor`): Video tensor of shape ``(batch, channels, frames, height, width)``. - Expected to be in RGB format with values in range [-1, 1] or [0, 1]. + Expected to be in RGB format with values in range [0, 1]. As output of `forward` and `compute` the metric returns the following output - If `elementary_features` is False: - - ``vmaf`` (:class:`~torch.Tensor`): A tensor with the VMAF score for each video in the batch. - Higher scores indicate better quality, with typical values ranging from 0 to 100. - - If `elementary_features` is True: + - ``vmaf`` (:class:`~torch.Tensor`): If `elementary_features` is set to False, a single tensor with the VMAF score + for each video in the batch. 
If `elementary_features` is True, a tuple of tensors is returned: + - ``vmaf_score`` (:class:`~torch.Tensor`): The main VMAF score tensor - ``adm_score`` (:class:`~torch.Tensor`): The Additive Detail Measure (ADM) score tensor, which measures the preservation of details in the video. Shape is (batch * frames, 4) where the 4 values represent @@ -76,8 +74,8 @@ class VideoMultiMethodAssessmentFusion(Metric): Example: >>> from torch import rand >>> from torchmetrics.video import VideoMultiMethodAssessmentFusion - >>> preds = rand(2, 3, 10, 32, 32) - >>> target = rand(2, 3, 10, 32, 32) + >>> preds = rand(2, 3, 10, 32, 32) # 2 videos, 3 channels, 10 frames, 32x32 resolution + >>> target = rand(2, 3, 10, 32, 32) # 2 videos, 3 channels, 10 frames, 32x32 resolution >>> vmaf = VideoMultiMethodAssessmentFusion() >>> vmaf(preds, target) tensor([12.6859, 15.1940, 14.6993, 14.9718, 19.1301, 17.1650]) From a937f26ea8330acb521a9ac2a0add1252bbe0c19 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 16 Apr 2025 07:51:26 +0000 Subject: [PATCH 17/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/functional/video/vmaf.py | 1 - src/torchmetrics/video/vmaf.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/torchmetrics/functional/video/vmaf.py b/src/torchmetrics/functional/video/vmaf.py index 272e8b0be0e..b3888794629 100644 --- a/src/torchmetrics/functional/video/vmaf.py +++ b/src/torchmetrics/functional/video/vmaf.py @@ -13,7 +13,6 @@ # limitations under the License. from typing import Tuple, Union -import torch import vmaf_torch from einops import rearrange from torch import Tensor diff --git a/src/torchmetrics/video/vmaf.py b/src/torchmetrics/video/vmaf.py index d4cac307f8c..16247033201 100644 --- a/src/torchmetrics/video/vmaf.py +++ b/src/torchmetrics/video/vmaf.py @@ -50,9 +50,9 @@ class VideoMultiMethodAssessmentFusion(Metric): As output of `forward` and `compute` the metric returns the following output - - ``vmaf`` (:class:`~torch.Tensor`): If `elementary_features` is set to False, a single tensor with the VMAF score + - ``vmaf`` (:class:`~torch.Tensor`): If `elementary_features` is set to False, a single tensor with the VMAF score for each video in the batch. If `elementary_features` is True, a tuple of tensors is returned: - + - ``vmaf_score`` (:class:`~torch.Tensor`): The main VMAF score tensor - ``adm_score`` (:class:`~torch.Tensor`): The Additive Detail Measure (ADM) score tensor, which measures the preservation of details in the video. 
Shape is (batch * frames, 4) where the 4 values represent
From 5733e823c8623146d38539b77201a3d482bca6e6 Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Tue, 29 Apr 2025 10:21:53 +0200
Subject: [PATCH 18/24] fix functional implementation

---
 CHANGELOG.md                              |   1 +
 src/torchmetrics/functional/video/vmaf.py | 113 ++++++++++++++--------
 2 files changed, 75 insertions(+), 39 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7cecd9f5d86..e7d8dd1d726 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Added support single `str` input for functional interface of `bert_score` ([#3056](https://github.com/Lightning-AI/torchmetrics/pull/3056))
 
+
 ### Changed
 
 -
 
diff --git a/src/torchmetrics/functional/video/vmaf.py b/src/torchmetrics/functional/video/vmaf.py
index b3888794629..628d671b43d 100644
--- a/src/torchmetrics/functional/video/vmaf.py
+++ b/src/torchmetrics/functional/video/vmaf.py
@@ -22,9 +22,9 @@
 
 def calculate_luma(video: Tensor) -> Tensor:
     """Calculate the luma component of a video tensor."""
-    r = video[:, :, 0, :, :]
-    g = video[:, :, 1, :, :]
-    b = video[:, :, 2, :, :]
+    r = video[:, 0, :, :, :]
+    g = video[:, 1, :, :, :]
+    b = video[:, 2, :, :, :]
     return (0.299 * r + 0.587 * g + 0.114 * b).unsqueeze(1) * 255  # [0, 1] -> [0, 255]
 
 
@@ -59,19 +59,19 @@ def video_multi_method_assessment_fusion(
         elementary_features: If True, returns the elementary features used by VMAF.
 
     Returns:
-        If `elementary_features` is False, returns a tensor with the VMAF score for each video in the batch.
-        Higher scores indicate better quality, with typical values ranging from 0 to 100.
+        If `elementary_features` is False, returns a tensor with shape (batch, frames) of VMAF scores for each frame in
+        each video. Higher scores indicate better quality, with typical values ranging from 0 to 100.
 
         If `elementary_features` is True, returns a tuple of four tensors:
-        - vmaf_score: The main VMAF score tensor
+        - vmaf_score: The main VMAF score tensor of shape (batch, frames)
         - adm_score: The Additive Detail Measure (ADM) score tensor, which measures the preservation of details
-          in the video. Shape is (batch * frames, 4) where the 4 values represent different detail scales.
+          in the video. Shape is (batch, frames, 4) where the 4 values represent different detail scales.
           Higher values indicate better detail preservation.
         - vif_score: The Visual Information Fidelity (VIF) score tensor, which measures the preservation of
-          visual information. Shape is (batch * frames, 4) where the 4 values represent different frequency bands.
+          visual information. Shape is (batch, frames, 4) where the 4 values represent different frequency bands.
           Higher values indicate better information preservation.
         - motion_score: The motion score tensor, which measures the amount of motion in the video.
-          Shape is (batch * frames,). Higher values indicate more motion.
+          Shape is (batch, frames). Higher values indicate more motion.
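The per-frame scores documented above still have to be pooled into a single number per clip; a minimal usage sketch of mean pooling over frames, assuming the functional interface as documented in this patch (mean pooling is one common convention, not something the function does for you):

    import torch
    from torchmetrics.functional.video import video_multi_method_assessment_fusion

    preds = torch.rand(2, 3, 10, 32, 32)   # (batch, channels, frames, height, width), RGB in [0, 1]
    target = torch.rand(2, 3, 10, 32, 32)
    per_frame = video_multi_method_assessment_fusion(preds, target)  # shape (2, 10), one score per frame
    per_video = per_frame.mean(dim=-1)     # hypothetical clip-level pooling, shape (2,)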
Example: >>> import torch @@ -79,33 +79,64 @@ def video_multi_method_assessment_fusion( >>> preds = torch.rand(2, 3, 10, 32, 32) # 2 videos, 3 channels, 10 frames, 32x32 resolution >>> target = torch.rand(2, 3, 10, 32, 32) # 2 videos, 3 channels, 10 frames, 32x32 resolution >>> video_multi_method_assessment_fusion(preds, target) - tensor([ 3.9553, 15.2808, 15.0131, 13.7132, 14.0283, 16.9560]) + tensor([[ 7.0141, 17.4276, 15.1429, 14.9831, 19.3378, 12.5638, 13.9680, 13.4165, 17.9314, 15.6604], + [13.9790, 11.1951, 15.3990, 13.5877, 15.1370, 18.4508, 17.5596, 18.6859, 12.9309, 15.1975]]) >>> vmaf_score, adm_score, vif_score, motion_score = video_multi_method_assessment_fusion( ... preds, target, elementary_features=True ... ) >>> vmaf_score - tensor([ 3.9553, 15.2808, 15.0131, 13.7132, 14.0283, 16.9560]) - >>> adm_score - tensor([[0.5128, 0.3583, 0.3275, 0.3798], - [0.5034, 0.3581, 0.3421, 0.4753], - [0.5161, 0.3987, 0.3176, 0.2830], - [0.4823, 0.3802, 0.3569, 0.4263], - [0.4627, 0.4267, 0.2862, 0.5625], - [0.5576, 0.4112, 0.3703, 0.5333]]) - >>> vif_score - tensor([[5.5589e-04, 2.3668e-02, 5.9746e-02, 1.8287e-01], - [6.3305e-04, 2.1592e-02, 3.8126e-02, 1.1630e-01], - [1.0766e-03, 1.9478e-02, 3.5908e-02, 5.0494e-02], - [1.4880e-04, 1.6239e-02, 2.6883e-02, 1.5944e-01], - [2.1966e-04, 2.2355e-02, 6.6175e-02, 5.8169e-02], - [1.0138e-03, 1.5265e-02, 5.5632e-02, 1.2230e-01]]) + tensor([[ 7.0141, 17.4276, 15.1429, 14.9831, 19.3378, 12.5638, 13.9680, 13.4165, 17.9314, 15.6604], + [13.9790, 11.1951, 15.3990, 13.5877, 15.1370, 18.4508, 17.5596, 18.6859, 12.9309, 15.1975]]) + >>> adm_score # doctest: +NORMALIZE_WHITESPACE + tensor([[[0.5052, 0.3689, 0.3906, 0.2976], + [0.4771, 0.3502, 0.3273, 0.7300], + [0.4881, 0.4082, 0.2437, 0.2755], + [0.4948, 0.3176, 0.4467, 0.3533], + [0.5658, 0.4589, 0.3998, 0.3095], + [0.5100, 0.4339, 0.5263, 0.4693], + [0.5351, 0.4767, 0.4267, 0.2752], + [0.5028, 0.3350, 0.3247, 0.4020], + [0.5071, 0.3949, 0.3832, 0.3111], + [0.4666, 0.3583, 0.4521, 0.2777]], + [[0.4686, 0.4700, 0.2433, 0.4896], + [0.4952, 0.3658, 0.3985, 0.4379], + [0.5445, 0.3839, 0.4010, 0.2285], + [0.5038, 0.3151, 0.4543, 0.3893], + [0.4899, 0.4008, 0.4266, 0.3279], + [0.5109, 0.3921, 0.3264, 0.5778], + [0.5315, 0.3788, 0.3103, 0.6088], + [0.4607, 0.4334, 0.4077, 0.4407], + [0.5017, 0.3816, 0.2890, 0.3553], + [0.5284, 0.4586, 0.3681, 0.2760]]]) + >>> vif_score # doctest: +NORMALIZE_WHITESPACE + tensor([[[3.9898e-04, 2.7862e-02, 3.1761e-02, 4.5509e-02], + [1.6094e-04, 1.1518e-02, 2.0446e-02, 8.4023e-02], + [3.7477e-04, 7.8991e-03, 6.6453e-03, 5.7339e-04], + [6.7157e-04, 9.9271e-03, 4.6627e-02, 4.6662e-02], + [9.6011e-04, 1.3214e-02, 2.7918e-02, 1.6376e-02], + [6.7778e-04, 6.1006e-02, 9.8535e-02, 2.5073e-01], + [1.1227e-03, 3.3202e-02, 6.4757e-02, 8.6356e-02], + [1.2290e-04, 1.3186e-02, 3.0758e-02, 1.0355e-01], + [5.8098e-04, 3.3142e-03, 7.3332e-04, 5.8651e-04], + [2.5460e-04, 5.2497e-03, 1.7505e-02, 3.1771e-02]], + [[3.6456e-04, 1.4340e-02, 2.9021e-02, 1.1958e-01], + [1.5903e-04, 3.4139e-02, 1.1511e-01, 1.3284e-01], + [9.7763e-04, 9.1875e-03, 2.0795e-02, 7.2092e-02], + [4.7811e-04, 3.0047e-02, 5.6494e-02, 1.3386e-01], + [1.1665e-03, 1.7940e-02, 5.3484e-02, 1.5105e-01], + [9.6759e-04, 1.7089e-02, 2.1730e-02, 7.3590e-03], + [4.2169e-04, 1.2152e-02, 1.4762e-02, 5.8642e-02], + [1.5370e-04, 1.1013e-02, 1.0387e-02, 1.2726e-02], + [1.0364e-03, 2.8013e-02, 3.8921e-02, 7.5270e-02], + [8.9485e-04, 2.3440e-02, 4.1318e-02, 9.4294e-02]]]) >>> motion_score - tensor([ 0.0000, 17.7642, 18.1769, 17.5795, 15.6578, 16.0557]) 
+ tensor([[ 0.0000, 15.9685, 15.9246, 15.7889, 17.3888, 19.0524, 13.7110, 16.0245, 16.1028, 15.5713], + [14.7679, 15.5407, 15.9964, 17.2818, 18.3270, 19.0149, 16.8640, 16.4841, 16.4464, 17.4890]]) """ if not _TORCH_VMAF_AVAILABLE: raise RuntimeError("vmaf-torch is not installed. Please install with `pip install torchmetrics[video]`.") - + b, f = preds.shape[0], preds.shape[2] orig_dtype, device = preds.dtype, preds.device preds_luma = calculate_luma(preds) target_luma = calculate_luma(target) @@ -113,16 +144,20 @@ def video_multi_method_assessment_fusion( vmaf = vmaf_torch.VMAF().to(device) score: Tensor = vmaf( - rearrange(target_luma, "b c t h w -> (b t) c h w"), rearrange(preds_luma, "b c t h w -> (b t) c h w") + rearrange(target_luma, "b c f h w -> (b f) c h w"), rearrange(preds_luma, "b c f h w -> (b f) c h w") ).to(orig_dtype) - - if elementary_features: - adm = vmaf.compute_adm_features( - rearrange(target_luma, "b c t h w -> (b t) c h w"), rearrange(preds_luma, "b c t h w -> (b t) c h w") - ) - vif = vmaf.compute_vif_features( - rearrange(target_luma, "b c t h w -> (b t) c h w"), rearrange(preds_luma, "b c t h w -> (b t) c h w") - ) - motion = vmaf.compute_motion(rearrange(target_luma, "b c t h w -> (b t) c h w")) - return score.squeeze(), adm, vif, motion.squeeze() - return score.squeeze() + score = rearrange(score, "(b f) 1 -> b f", b=b, f=f) + if not elementary_features: + return score + + adm = vmaf.compute_adm_features( + rearrange(target_luma, "b c f h w -> (b f) c h w"), rearrange(preds_luma, "b c f h w -> (b f) c h w") + ) + adm = rearrange(adm, "(b f) s -> b f s", b=b, f=f) # s=4 are the different scales + vif = vmaf.compute_vif_features( + rearrange(target_luma, "b c f h w -> (b f) c h w"), rearrange(preds_luma, "b c f h w -> (b f) c h w") + ) + vif = rearrange(vif, "(b f) s -> b f s", b=b, f=f) # s=4 are the different frequency bands + motion = vmaf.compute_motion(rearrange(target_luma, "b c f h w -> (b f) c h w")) + motion = rearrange(motion, "(b f) 1 -> b f", b=b, f=f) + return score, adm, vif, motion.squeeze() From 53f2bc53908e84ba2f66a302d1292bb4181fb9ba Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 29 Apr 2025 10:29:04 +0200 Subject: [PATCH 19/24] fix modular implementation --- src/torchmetrics/video/vmaf.py | 148 ++++++++++++++++----------------- 1 file changed, 74 insertions(+), 74 deletions(-) diff --git a/src/torchmetrics/video/vmaf.py b/src/torchmetrics/video/vmaf.py index 16247033201..489ab7c04e4 100644 --- a/src/torchmetrics/video/vmaf.py +++ b/src/torchmetrics/video/vmaf.py @@ -11,15 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Tuple, Union +from typing import Any, List, Tuple, Union -import torch from torch import Tensor from torchmetrics.functional.video.vmaf import video_multi_method_assessment_fusion from torchmetrics.metric import Metric +from torchmetrics.utilities.data import dim_zero_cat from torchmetrics.utilities.imports import _TORCH_VMAF_AVAILABLE -from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE class VideoMultiMethodAssessmentFusion(Metric): @@ -50,26 +49,30 @@ class VideoMultiMethodAssessmentFusion(Metric): As output of `forward` and `compute` the metric returns the following output - - ``vmaf`` (:class:`~torch.Tensor`): If `elementary_features` is set to False, a single tensor with the VMAF score - for each video in the batch. 
If `elementary_features` is True, a tuple of tensors is returned:
-
-    - ``vmaf_score`` (:class:`~torch.Tensor`): The main VMAF score tensor
-    - ``adm_score`` (:class:`~torch.Tensor`): The Additive Detail Measure (ADM) score tensor, which measures
-      the preservation of details in the video. Shape is (batch * frames, 4) where the 4 values represent
-      different detail scales. Higher values indicate better detail preservation.
-    - ``vif_score`` (:class:`~torch.Tensor`): The Visual Information Fidelity (VIF) score tensor, which measures
-      the preservation of visual information. Shape is (batch * frames, 4) where the 4 values represent different
-      frequency bands. Higher values indicate better information preservation.
-    - ``motion_score`` (:class:`~torch.Tensor`): The motion score tensor, which measures the amount of motion
-      in the video. Shape is (batch * frames,). Higher values indicate more motion.
+    - ``vmaf`` (:class:`~torch.Tensor`): If `elementary_features` is False, returns a tensor with shape (batch, frames)
+      of VMAF scores for each frame in each video. Higher scores indicate better quality, with typical values ranging
+      from 0 to 100.
+
+      If `elementary_features` is True, returns a tuple of four tensors:
+      - vmaf_score: The main VMAF score tensor of shape (batch, frames)
+      - adm_score: The Additive Detail Measure (ADM) score tensor, which measures the preservation of details
+        in the video. Shape is (batch, frames, 4) where the 4 values represent different detail scales.
+        Higher values indicate better detail preservation.
+      - vif_score: The Visual Information Fidelity (VIF) score tensor, which measures the preservation of
+        visual information. Shape is (batch, frames, 4) where the 4 values represent different frequency bands.
+        Higher values indicate better information preservation.
+      - motion_score: The motion score tensor, which measures the amount of motion in the video.
+        Shape is (batch, frames). Higher values indicate more motion.
 
     Args:
         elementary_features: If True, returns the elementary features used by VMAF.
         kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
 
     Raises:
-        RuntimeError: If vmaf-torch is not installed.
-        ValueError: If `elementary_features` is not a boolean.
+        RuntimeError:
+            If vmaf-torch is not installed.
+        ValueError:
+            If `elementary_features` is not a boolean.
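Because this is a stateful Metric, per-frame scores can be accumulated across several batches before a single compute call; a minimal sketch of that loop, with random batches standing in for a real dataloader:

    from torch import rand
    from torchmetrics.video import VideoMultiMethodAssessmentFusion

    vmaf = VideoMultiMethodAssessmentFusion()
    for _ in range(3):                      # stand-in for batches from a dataloader
        preds = rand(2, 3, 10, 32, 32)      # RGB in [0, 1]
        target = rand(2, 3, 10, 32, 32)
        vmaf.update(preds, target)
    scores = vmaf.compute()                 # per-frame scores for all 6 videos, shape (6, 10)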
Example: >>> from torch import rand @@ -78,27 +81,58 @@ class VideoMultiMethodAssessmentFusion(Metric): >>> target = rand(2, 3, 10, 32, 32) # 2 videos, 3 channels, 10 frames, 32x32 resolution >>> vmaf = VideoMultiMethodAssessmentFusion() >>> vmaf(preds, target) - tensor([12.6859, 15.1940, 14.6993, 14.9718, 19.1301, 17.1650]) + tensor([[ 7.0141, 17.4276, 15.1429, 14.9831, 19.3378, 12.5638, 13.9680, 13.4165, 17.9314, 15.6604], + [13.9790, 11.1951, 15.3990, 13.5877, 15.1370, 18.4508, 17.5596, 18.6859, 12.9309, 15.1975]]) >>> vmaf = VideoMultiMethodAssessmentFusion(elementary_features=True) >>> vmaf_score, adm_score, vif_score, motion_score = vmaf(preds, target) >>> vmaf_score - tensor([12.6859, 15.1940, 14.6993, 14.9718, 19.1301, 17.1650]) - >>> adm_score - tensor([[0.6258, 0.4526, 0.4360, 0.5100], - [0.6117, 0.4558, 0.4478, 0.5543], - [0.6253, 0.4867, 0.4116, 0.4412], - [0.6011, 0.4773, 0.4527, 0.5263], - [0.5830, 0.5209, 0.4050, 0.6781], - [0.6576, 0.5081, 0.4600, 0.6017]]) - >>> vif_score - tensor([[6.8940e-04, 3.5287e-02, 1.2094e-01, 6.7600e-01], - [7.8453e-04, 3.1258e-02, 6.3257e-02, 3.4321e-01], - [1.3337e-03, 2.8432e-02, 6.3114e-02, 4.6726e-01], - [1.8480e-04, 2.3861e-02, 1.5634e-01, 5.5803e-01], - [2.7257e-04, 3.4004e-02, 1.6240e-01, 6.9619e-01], - [1.2596e-03, 2.1799e-02, 1.0870e-01, 2.2582e-01]]) + tensor([[ 7.0141, 17.4276, 15.1429, 14.9831, 19.3378, 12.5638, 13.9680, 13.4165, 17.9314, 15.6604], + [13.9790, 11.1951, 15.3990, 13.5877, 15.1370, 18.4508, 17.5596, 18.6859, 12.9309, 15.1975]]) + >>> adm_score # doctest: +NORMALIZE_WHITESPACE + tensor([[[0.5052, 0.3689, 0.3906, 0.2976], + [0.4771, 0.3502, 0.3273, 0.7300], + [0.4881, 0.4082, 0.2437, 0.2755], + [0.4948, 0.3176, 0.4467, 0.3533], + [0.5658, 0.4589, 0.3998, 0.3095], + [0.5100, 0.4339, 0.5263, 0.4693], + [0.5351, 0.4767, 0.4267, 0.2752], + [0.5028, 0.3350, 0.3247, 0.4020], + [0.5071, 0.3949, 0.3832, 0.3111], + [0.4666, 0.3583, 0.4521, 0.2777]], + [[0.4686, 0.4700, 0.2433, 0.4896], + [0.4952, 0.3658, 0.3985, 0.4379], + [0.5445, 0.3839, 0.4010, 0.2285], + [0.5038, 0.3151, 0.4543, 0.3893], + [0.4899, 0.4008, 0.4266, 0.3279], + [0.5109, 0.3921, 0.3264, 0.5778], + [0.5315, 0.3788, 0.3103, 0.6088], + [0.4607, 0.4334, 0.4077, 0.4407], + [0.5017, 0.3816, 0.2890, 0.3553], + [0.5284, 0.4586, 0.3681, 0.2760]]]) + >>> vif_score # doctest: +NORMALIZE_WHITESPACE + tensor([[[3.9898e-04, 2.7862e-02, 3.1761e-02, 4.5509e-02], + [1.6094e-04, 1.1518e-02, 2.0446e-02, 8.4023e-02], + [3.7477e-04, 7.8991e-03, 6.6453e-03, 5.7339e-04], + [6.7157e-04, 9.9271e-03, 4.6627e-02, 4.6662e-02], + [9.6011e-04, 1.3214e-02, 2.7918e-02, 1.6376e-02], + [6.7778e-04, 6.1006e-02, 9.8535e-02, 2.5073e-01], + [1.1227e-03, 3.3202e-02, 6.4757e-02, 8.6356e-02], + [1.2290e-04, 1.3186e-02, 3.0758e-02, 1.0355e-01], + [5.8098e-04, 3.3142e-03, 7.3332e-04, 5.8651e-04], + [2.5460e-04, 5.2497e-03, 1.7505e-02, 3.1771e-02]], + [[3.6456e-04, 1.4340e-02, 2.9021e-02, 1.1958e-01], + [1.5903e-04, 3.4139e-02, 1.1511e-01, 1.3284e-01], + [9.7763e-04, 9.1875e-03, 2.0795e-02, 7.2092e-02], + [4.7811e-04, 3.0047e-02, 5.6494e-02, 1.3386e-01], + [1.1665e-03, 1.7940e-02, 5.3484e-02, 1.5105e-01], + [9.6759e-04, 1.7089e-02, 2.1730e-02, 7.3590e-03], + [4.2169e-04, 1.2152e-02, 1.4762e-02, 5.8642e-02], + [1.5370e-04, 1.1013e-02, 1.0387e-02, 1.2726e-02], + [1.0364e-03, 2.8013e-02, 3.8921e-02, 7.5270e-02], + [8.9485e-04, 2.3440e-02, 4.1318e-02, 9.4294e-02]]]) >>> motion_score - tensor([0.0000, 8.8821, 9.0885, 8.7898, 7.8289, 8.0279]) + tensor([[ 0.0000, 15.9685, 15.9246, 15.7889, 17.3888, 19.0524, 
13.7110, 16.0245, 16.1028, 15.5713], + [14.7679, 15.5407, 15.9964, 17.2818, 18.3270, 19.0149, 16.8640, 16.4841, 16.4464, 17.4890]]) """ @@ -106,7 +140,7 @@ class VideoMultiMethodAssessmentFusion(Metric): higher_is_better: bool = True full_state_update: bool = False plot_lower_bound: float = 0.0 - plot_upper_bound: float = 100.0 # Updated to match VMAF score range + plot_upper_bound: float = 100.0 vmaf_score: List[Tensor] adm_features: List[Tensor] @@ -143,43 +177,9 @@ def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor, Tensor]]: """Compute final VMAF score.""" if self.elementary_features: return ( - torch.cat(self.vmaf_score, dim=0), - torch.cat(self.adm_features, dim=0), - torch.cat(self.vif_features, dim=0), - torch.cat(self.motion, dim=0), + dim_zero_cat(self.vmaf_score), + dim_zero_cat(self.adm_features), + dim_zero_cat(self.vif_features), + dim_zero_cat(self.motion), ) - return torch.cat(self.vmaf_score, dim=0) - - def plot( - self, val: Optional[Union[Tensor, Tuple[Tensor, Tensor, Tensor, Tensor]]] = None, ax: Optional[_AX_TYPE] = None - ) -> _PLOT_OUT_TYPE: - """Plot a single or multiple values from the metric. - - Args: - val: Either a single result from calling `metric.forward()` or `metric.compute()`, or the results - from multiple calls of `metric.forward()` or `metric.compute()`. If no value is provided, will - automatically call `metric.compute()` and plot that result. - ax: An matplotlib axis object. If provided will add plot to that axis - - Returns: - Figure and Axes object - - Raises: - ModuleNotFoundError: - If `matplotlib` is not installed - - .. plot:: - :scale: 75 - :caption: Example of plotting VMAF scores and elementary features. - - >>> # Example plotting a single value - >>> from torch import rand - >>> from torchmetrics.video import VideoMultiMethodAssessmentFusion - >>> metric = VideoMultiMethodAssessmentFusion(elementary_features=True) - >>> preds = rand(2, 3, 10, 32, 32) - >>> target = rand(2, 3, 10, 32, 32) - >>> metric.update(preds, target) - >>> fig_, ax_ = metric.plot() - - """ - return super().plot(val=val, ax=ax) + return dim_zero_cat(self.vmaf_score) From c8730621a93fd29a45ab68aa5890f654ceb202c7 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 29 Apr 2025 11:46:31 +0200 Subject: [PATCH 20/24] redo functional implementation --- src/torchmetrics/functional/video/vmaf.py | 146 +++++++++------------- 1 file changed, 59 insertions(+), 87 deletions(-) diff --git a/src/torchmetrics/functional/video/vmaf.py b/src/torchmetrics/functional/video/vmaf.py index 628d671b43d..11501e468be 100644 --- a/src/torchmetrics/functional/video/vmaf.py +++ b/src/torchmetrics/functional/video/vmaf.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple, Union +from typing import Dict, Union +import torch import vmaf_torch from einops import rearrange from torch import Tensor @@ -31,8 +32,8 @@ def calculate_luma(video: Tensor) -> Tensor: def video_multi_method_assessment_fusion( preds: Tensor, target: Tensor, - elementary_features: bool = False, -) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor, Tensor]]: + features: bool = False, +) -> Union[Tensor, Dict[str, Tensor]]: """Calculates Video Multi-Method Assessment Fusion (VMAF) metric. 
VMAF is a full-reference video quality assessment algorithm that combines multiple quality assessment features
 
     Args:
         preds: Video tensor of shape (batch, channels, frames, height, width). Expected to be in RGB format
             with values in range [0, 1].
         target: Video tensor of shape (batch, channels, frames, height, width). Expected to be in RGB format
             with values in range [0, 1].
-        elementary_features: If True, returns the elementary features used by VMAF.
+        features: If True, all the elementary features (ADM, VIF, motion) are returned along with the VMAF score in
+            a dictionary. This corresponds to the output you would get from the VMAF command line tool with the `--csv`
+            option enabled. If False, only the VMAF score is returned as a tensor.
 
     Returns:
-        If `elementary_features` is False, returns a tensor with shape (batch, frames) of VMAF scores for each frame in
+        If `features` is False, returns a tensor with shape (batch, frames) of VMAF scores for each frame in
         each video. Higher scores indicate better quality, with typical values ranging from 0 to 100.
 
-        If `elementary_features` is True, returns a tuple of four tensors:
-        - vmaf_score: The main VMAF score tensor of shape (batch, frames)
-        - adm_score: The Additive Detail Measure (ADM) score tensor, which measures the preservation of details
-          in the video. Shape is (batch, frames, 4) where the 4 values represent different detail scales.
-          Higher values indicate better detail preservation.
-        - vif_score: The Visual Information Fidelity (VIF) score tensor, which measures the preservation of
-          visual information. Shape is (batch, frames, 4) where the 4 values represent different frequency bands.
-          Higher values indicate better information preservation.
-        - motion_score: The motion score tensor, which measures the amount of motion in the video.
-          Shape is (batch, frames). Higher values indicate more motion.
+        If `features` is True, returns a dictionary where each value is a (batch, frames) tensor of the
+        corresponding feature. The keys are:
+        - 'integer_motion2': Integer motion feature
+        - 'integer_motion': Integer motion feature
+        - 'integer_adm2': Integer ADM feature
+        - 'integer_adm_scale0': Integer ADM feature at scale 0
+        - 'integer_adm_scale1': Integer ADM feature at scale 1
+        - 'integer_adm_scale2': Integer ADM feature at scale 2
+        - 'integer_adm_scale3': Integer ADM feature at scale 3
+        - 'integer_vif_scale0': Integer VIF feature at scale 0
+        - 'integer_vif_scale1': Integer VIF feature at scale 1
+        - 'integer_vif_scale2': Integer VIF feature at scale 2
+        - 'integer_vif_scale3': Integer VIF feature at scale 3
+        - 'vmaf': VMAF score for each frame in each video
 
     Example:
         >>> import torch
         >>> from torchmetrics.functional.video import video_multi_method_assessment_fusion
         >>> preds = torch.rand(2, 3, 10, 32, 32)  # 2 videos, 3 channels, 10 frames, 32x32 resolution
         >>> target = torch.rand(2, 3, 10, 32, 32)  # 2 videos, 3 channels, 10 frames, 32x32 resolution
         >>> video_multi_method_assessment_fusion(preds, target)
-        tensor([[ 7.0141, 17.4276, 15.1429, 14.9831, 19.3378, 12.5638, 13.9680, 13.4165, 17.9314, 15.6604],
-                [13.9790, 11.1951, 15.3990, 13.5877, 15.1370, 18.4508, 17.5596, 18.6859, 12.9309, 15.1975]])
-        >>> vmaf_score, adm_score, vif_score, motion_score = video_multi_method_assessment_fusion(
-        ...     preds, target, elementary_features=True
-        ...
) - >>> vmaf_score - tensor([[ 7.0141, 17.4276, 15.1429, 14.9831, 19.3378, 12.5638, 13.9680, 13.4165, 17.9314, 15.6604], - [13.9790, 11.1951, 15.3990, 13.5877, 15.1370, 18.4508, 17.5596, 18.6859, 12.9309, 15.1975]]) - >>> adm_score # doctest: +NORMALIZE_WHITESPACE - tensor([[[0.5052, 0.3689, 0.3906, 0.2976], - [0.4771, 0.3502, 0.3273, 0.7300], - [0.4881, 0.4082, 0.2437, 0.2755], - [0.4948, 0.3176, 0.4467, 0.3533], - [0.5658, 0.4589, 0.3998, 0.3095], - [0.5100, 0.4339, 0.5263, 0.4693], - [0.5351, 0.4767, 0.4267, 0.2752], - [0.5028, 0.3350, 0.3247, 0.4020], - [0.5071, 0.3949, 0.3832, 0.3111], - [0.4666, 0.3583, 0.4521, 0.2777]], - [[0.4686, 0.4700, 0.2433, 0.4896], - [0.4952, 0.3658, 0.3985, 0.4379], - [0.5445, 0.3839, 0.4010, 0.2285], - [0.5038, 0.3151, 0.4543, 0.3893], - [0.4899, 0.4008, 0.4266, 0.3279], - [0.5109, 0.3921, 0.3264, 0.5778], - [0.5315, 0.3788, 0.3103, 0.6088], - [0.4607, 0.4334, 0.4077, 0.4407], - [0.5017, 0.3816, 0.2890, 0.3553], - [0.5284, 0.4586, 0.3681, 0.2760]]]) - >>> vif_score # doctest: +NORMALIZE_WHITESPACE - tensor([[[3.9898e-04, 2.7862e-02, 3.1761e-02, 4.5509e-02], - [1.6094e-04, 1.1518e-02, 2.0446e-02, 8.4023e-02], - [3.7477e-04, 7.8991e-03, 6.6453e-03, 5.7339e-04], - [6.7157e-04, 9.9271e-03, 4.6627e-02, 4.6662e-02], - [9.6011e-04, 1.3214e-02, 2.7918e-02, 1.6376e-02], - [6.7778e-04, 6.1006e-02, 9.8535e-02, 2.5073e-01], - [1.1227e-03, 3.3202e-02, 6.4757e-02, 8.6356e-02], - [1.2290e-04, 1.3186e-02, 3.0758e-02, 1.0355e-01], - [5.8098e-04, 3.3142e-03, 7.3332e-04, 5.8651e-04], - [2.5460e-04, 5.2497e-03, 1.7505e-02, 3.1771e-02]], - [[3.6456e-04, 1.4340e-02, 2.9021e-02, 1.1958e-01], - [1.5903e-04, 3.4139e-02, 1.1511e-01, 1.3284e-01], - [9.7763e-04, 9.1875e-03, 2.0795e-02, 7.2092e-02], - [4.7811e-04, 3.0047e-02, 5.6494e-02, 1.3386e-01], - [1.1665e-03, 1.7940e-02, 5.3484e-02, 1.5105e-01], - [9.6759e-04, 1.7089e-02, 2.1730e-02, 7.3590e-03], - [4.2169e-04, 1.2152e-02, 1.4762e-02, 5.8642e-02], - [1.5370e-04, 1.1013e-02, 1.0387e-02, 1.2726e-02], - [1.0364e-03, 2.8013e-02, 3.8921e-02, 7.5270e-02], - [8.9485e-04, 2.3440e-02, 4.1318e-02, 9.4294e-02]]]) - >>> motion_score - tensor([[ 0.0000, 15.9685, 15.9246, 15.7889, 17.3888, 19.0524, 13.7110, 16.0245, 16.1028, 15.5713], - [14.7679, 15.5407, 15.9964, 17.2818, 18.3270, 19.0149, 16.8640, 16.4841, 16.4464, 17.4890]]) + tensor([[ 7.0141, 17.4276, 15.1429, 14.9831, 19.3378, 12.5638, 13.9680, 13.4165, 17.9314, 16.0849], + [ 7.2760, 11.1951, 15.3990, 13.5877, 15.1370, 18.4508, 17.5596, 18.6859, 12.9309, 15.1975]]) + >>> vmaf_dict = video_multi_method_assessment_fusion(preds, target, features=True) + >>> # show a couple of features, more features are available + >>> vmaf_dict['vmaf'] + tensor([[ 7.0141, 17.4276, 15.1429, 14.9831, 19.3378, 12.5638, 13.9680, 13.4165, 17.9314, 16.0849], + [ 7.2760, 11.1951, 15.3990, 13.5877, 15.1370, 18.4508, 17.5596, 18.6859, 12.9309, 15.1975]]) + >>> vmaf_dict['integer_adm2'] + tensor([[0.3980, 0.4449, 0.3650, 0.4070, 0.4504, 0.4817, 0.4489, 0.3870, 0.4140, 0.4029], + [0.4134, 0.4216, 0.4050, 0.4138, 0.4209, 0.4378, 0.4333, 0.4358, 0.3846, 0.4168]]) + >>> vmaf_dict['integer_vif_scale0'] + tensor([[0.0004, 0.0002, 0.0004, 0.0007, 0.0010, 0.0007, 0.0011, 0.0001, 0.0006, 0.0003], + [0.0004, 0.0002, 0.0010, 0.0005, 0.0012, 0.0010, 0.0004, 0.0002, 0.0010, 0.0009]]) """ if not _TORCH_VMAF_AVAILABLE: raise RuntimeError("vmaf-torch is not installed. 
Please install with `pip install torchmetrics[video]`.") - b, f = preds.shape[0], preds.shape[2] + b = preds.shape[0] orig_dtype, device = preds.dtype, preds.device preds_luma = calculate_luma(preds) target_luma = calculate_luma(target) vmaf = vmaf_torch.VMAF().to(device) - score: Tensor = vmaf( - rearrange(target_luma, "b c f h w -> (b f) c h w"), rearrange(preds_luma, "b c f h w -> (b f) c h w") - ).to(orig_dtype) - score = rearrange(score, "(b f) 1 -> b f", b=b, f=f) - if not elementary_features: - return score - - adm = vmaf.compute_adm_features( - rearrange(target_luma, "b c f h w -> (b f) c h w"), rearrange(preds_luma, "b c f h w -> (b f) c h w") - ) - adm = rearrange(adm, "(b f) s -> b f s", b=b, f=f) # s=4 are the different scales - vif = vmaf.compute_vif_features( - rearrange(target_luma, "b c f h w -> (b f) c h w"), rearrange(preds_luma, "b c f h w -> (b f) c h w") - ) - vif = rearrange(vif, "(b f) s -> b f s", b=b, f=f) # s=4 are the different frequency bands - motion = vmaf.compute_motion(rearrange(target_luma, "b c f h w -> (b f) c h w")) - motion = rearrange(motion, "(b f) 1 -> b f", b=b, f=f) - return score, adm, vif, motion.squeeze() + # we need to compute the model for each video separately + if not features: + scores = [ + vmaf.compute_vmaf_score( + rearrange(target_luma[video], "c f h w -> f c h w"), rearrange(preds_luma[video], "c f h w -> f c h w") + ) + for video in range(b) + ] + return torch.cat(scores, dim=1).t().to(orig_dtype) + import pandas as pd # pandas is installed as a dependency of vmaf-torch + + scores_and_features = [ + vmaf.table( + rearrange(target_luma[video], "c f h w -> f c h w"), rearrange(preds_luma[video], "c f h w -> f c h w") + ) + for video in range(b) + ] + dfs = [scores_and_features[video].apply(pd.to_numeric, errors="coerce") for video in range(b)] + result = [ + {col: torch.tensor(dfs[video][col].values, dtype=orig_dtype) for col in dfs[video].columns if col != "Frame"} + for video in range(b) + ] + return {col: torch.stack([result[video][col] for video in range(b)]) for col in result[0]} From e5500b3a793cd7ec08bf7311cb473f59b1678778 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 29 Apr 2025 11:50:18 +0200 Subject: [PATCH 21/24] fix modular implementation --- src/torchmetrics/video/vmaf.py | 82 +++++++++++++++++++++++----------- 1 file changed, 57 insertions(+), 25 deletions(-) diff --git a/src/torchmetrics/video/vmaf.py b/src/torchmetrics/video/vmaf.py index 489ab7c04e4..992b3031dbe 100644 --- a/src/torchmetrics/video/vmaf.py +++ b/src/torchmetrics/video/vmaf.py @@ -143,43 +143,75 @@ class VideoMultiMethodAssessmentFusion(Metric): plot_upper_bound: float = 100.0 vmaf_score: List[Tensor] - adm_features: List[Tensor] - vif_features: List[Tensor] - motion: List[Tensor] - - def __init__(self, elementary_features: bool = False, **kwargs: Any) -> None: + integer_motion2: List[Tensor] + integer_motion: List[Tensor] + integer_adm2: List[Tensor] + integer_adm_scale0: List[Tensor] + integer_adm_scale1: List[Tensor] + integer_adm_scale2: List[Tensor] + integer_adm_scale3: List[Tensor] + integer_vif_scale0: List[Tensor] + integer_vif_scale1: List[Tensor] + integer_vif_scale2: List[Tensor] + integer_vif_scale3: List[Tensor] + + def __init__(self, features: bool = False, **kwargs: Any) -> None: super().__init__(**kwargs) if not _TORCH_VMAF_AVAILABLE: raise RuntimeError("vmaf-torch is not installed. 
Please install with `pip install torchmetrics[video]`.")
-        if not isinstance(elementary_features, bool):
-            raise ValueError("Argument `elementary_features` should be a boolean, but got {elementary_features}.")
-        self.elementary_features = elementary_features
+        if not isinstance(features, bool):
+            raise ValueError(f"Argument `features` should be a boolean, but got {features}.")
+        self.features = features
 
         self.add_state("vmaf_score", default=[], dist_reduce_fx="cat")
-        if self.elementary_features:
-            self.add_state("adm_features", default=[], dist_reduce_fx="cat")
-            self.add_state("vif_features", default=[], dist_reduce_fx="cat")
-            self.add_state("motion", default=[], dist_reduce_fx="cat")
+        if self.features:
+            self.add_state("integer_motion2", default=[], dist_reduce_fx="cat")
+            self.add_state("integer_motion", default=[], dist_reduce_fx="cat")
+            self.add_state("integer_adm2", default=[], dist_reduce_fx="cat")
+            self.add_state("integer_adm_scale0", default=[], dist_reduce_fx="cat")
+            self.add_state("integer_adm_scale1", default=[], dist_reduce_fx="cat")
+            self.add_state("integer_adm_scale2", default=[], dist_reduce_fx="cat")
+            self.add_state("integer_adm_scale3", default=[], dist_reduce_fx="cat")
+            self.add_state("integer_vif_scale0", default=[], dist_reduce_fx="cat")
+            self.add_state("integer_vif_scale1", default=[], dist_reduce_fx="cat")
+            self.add_state("integer_vif_scale2", default=[], dist_reduce_fx="cat")
+            self.add_state("integer_vif_scale3", default=[], dist_reduce_fx="cat")
 
     def update(self, preds: Tensor, target: Tensor) -> None:
         """Update state with predictions and targets."""
-        score = video_multi_method_assessment_fusion(preds, target, self.elementary_features)
-        if self.elementary_features:
-            self.vmaf_score.append(score[0])
-            self.adm_features.append(score[1])
-            self.vif_features.append(score[2])
-            self.motion.append(score[3])
+        score = video_multi_method_assessment_fusion(preds, target, self.features)
+        if self.features:
+            self.vmaf_score.append(score["vmaf"])
+            self.integer_motion2.append(score["integer_motion2"])
+            self.integer_motion.append(score["integer_motion"])
+            self.integer_adm2.append(score["integer_adm2"])
+            self.integer_adm_scale0.append(score["integer_adm_scale0"])
+            self.integer_adm_scale1.append(score["integer_adm_scale1"])
+            self.integer_adm_scale2.append(score["integer_adm_scale2"])
+            self.integer_adm_scale3.append(score["integer_adm_scale3"])
+            self.integer_vif_scale0.append(score["integer_vif_scale0"])
+            self.integer_vif_scale1.append(score["integer_vif_scale1"])
+            self.integer_vif_scale2.append(score["integer_vif_scale2"])
+            self.integer_vif_scale3.append(score["integer_vif_scale3"])
         else:
             self.vmaf_score.append(score)
 
     def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor, Tensor]]:
         """Compute final VMAF score."""
-        if self.elementary_features:
-            return (
-                dim_zero_cat(self.vmaf_score),
-                dim_zero_cat(self.adm_features),
-                dim_zero_cat(self.vif_features),
-                dim_zero_cat(self.motion),
-            )
+        if self.features:
+            return {
+                "vmaf": dim_zero_cat(self.vmaf_score),
+                "integer_motion2": dim_zero_cat(self.integer_motion2),
+                "integer_motion": dim_zero_cat(self.integer_motion),
+                "integer_adm2": dim_zero_cat(self.integer_adm2),
+                "integer_adm_scale0": dim_zero_cat(self.integer_adm_scale0),
+                "integer_adm_scale1": dim_zero_cat(self.integer_adm_scale1),
+                "integer_adm_scale2": dim_zero_cat(self.integer_adm_scale2),
+                "integer_adm_scale3": dim_zero_cat(self.integer_adm_scale3),
+                "integer_vif_scale0": dim_zero_cat(self.integer_vif_scale0),
"integer_vif_scale1": dim_zero_cat(self.integer_vif_scale1), + "integer_vif_scale2": dim_zero_cat(self.integer_vif_scale2), + "integer_vif_scale3": dim_zero_cat(self.integer_vif_scale3), + } return dim_zero_cat(self.vmaf_score) From 3358bd5419d5968da46751c37fc0cbf082ae8ed2 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 29 Apr 2025 11:54:32 +0200 Subject: [PATCH 22/24] fix modular implementation --- src/torchmetrics/video/vmaf.py | 107 ++++++++++++--------------------- 1 file changed, 37 insertions(+), 70 deletions(-) diff --git a/src/torchmetrics/video/vmaf.py b/src/torchmetrics/video/vmaf.py index 992b3031dbe..b820eaf098c 100644 --- a/src/torchmetrics/video/vmaf.py +++ b/src/torchmetrics/video/vmaf.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Tuple, Union +from typing import Any, Dict, List, Union from torch import Tensor @@ -49,30 +49,36 @@ class VideoMultiMethodAssessmentFusion(Metric): As output of `forward` and `compute` the metric returns the following output - - ``vmaf`` (:class:`~torch.Tensor`): If `elementary_features` is False, returns a tensor with shape (batch, frame) - of VMAF score for each frame in each video. Higher scores indicate better quality, with typical values ranging - from 0 to 100. - - If `elementary_features` is True, returns a tuple of four tensors: - - vmaf_score: The main VMAF score tensor of shape (batch, frames) - - adm_score: The Additive Detail Measure (ADM) score tensor, which measures the preservation of details - in the video. Shape is (batch, frames, 4) where the 4 values represent different detail scales. - Higher values indicate better detail preservation. - - vif_score: The Visual Information Fidelity (VIF) score tensor, which measures the preservation of - visual information. Shape is (batch, frames, 4) where the 4 values represent different frequency bands. - Higher values indicate better information preservation. - - motion_score: The motion score tensor, which measures the amount of motion in the video. - Shape is (batch, frames). Higher values indicate more motion. + - ``vmaf`` (:class:`~torch.Tensor`): If `features` is False, returns a tensor with shape (batch, frame) of VMAF + score for each frame in each video. Higher scores indicate better quality, with typical values ranging from + 0 to 100. + + If `features` is True, returns a dictionary where each value is a (batch, frame) tensor of the + corresponding feature. The keys are: + - 'integer_motion2': Integer motion feature + - 'integer_motion': Integer motion feature + - 'integer_adm2': Integer ADM feature + - 'integer_adm_scale0': Integer ADM feature at scale 0 + - 'integer_adm_scale1': Integer ADM feature at scale 1 + - 'integer_adm_scale2': Integer ADM feature at scale 2 + - 'integer_adm_scale3': Integer ADM feature at scale 3 + - 'integer_vif_scale0': Integer VIF feature at scale 0 + - 'integer_vif_scale1': Integer VIF feature at scale 1 + - 'integer_vif_scale2': Integer VIF feature at scale 2 + - 'integer_vif_scale3': Integer VIF feature at scale 3 + - 'vmaf': VMAF score for each frame in each video Args: - elementary_features: If True, returns the elementary features used by VMAF. + features: If True, all the elementary features (ADM, VIF, motion) are returned along with the VMAF score in + a dictionary. 
This corresponds to the output you would get from the VMAF command line tool with the `--csv` + option enabled. If False, only the VMAF score is returned as a tensor. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Raises: RuntimeError: If vmaf-torch is not installed. ValueError: - If `elementary_features` is not a boolean. + If `features` is not a boolean. Example: >>> from torch import rand @@ -81,58 +87,19 @@ class VideoMultiMethodAssessmentFusion(Metric): >>> target = rand(2, 3, 10, 32, 32) # 2 videos, 3 channels, 10 frames, 32x32 resolution >>> vmaf = VideoMultiMethodAssessmentFusion() >>> vmaf(preds, target) - tensor([[ 7.0141, 17.4276, 15.1429, 14.9831, 19.3378, 12.5638, 13.9680, 13.4165, 17.9314, 15.6604], - [13.9790, 11.1951, 15.3990, 13.5877, 15.1370, 18.4508, 17.5596, 18.6859, 12.9309, 15.1975]]) - >>> vmaf = VideoMultiMethodAssessmentFusion(elementary_features=True) - >>> vmaf_score, adm_score, vif_score, motion_score = vmaf(preds, target) - >>> vmaf_score - tensor([[ 7.0141, 17.4276, 15.1429, 14.9831, 19.3378, 12.5638, 13.9680, 13.4165, 17.9314, 15.6604], - [13.9790, 11.1951, 15.3990, 13.5877, 15.1370, 18.4508, 17.5596, 18.6859, 12.9309, 15.1975]]) - >>> adm_score # doctest: +NORMALIZE_WHITESPACE - tensor([[[0.5052, 0.3689, 0.3906, 0.2976], - [0.4771, 0.3502, 0.3273, 0.7300], - [0.4881, 0.4082, 0.2437, 0.2755], - [0.4948, 0.3176, 0.4467, 0.3533], - [0.5658, 0.4589, 0.3998, 0.3095], - [0.5100, 0.4339, 0.5263, 0.4693], - [0.5351, 0.4767, 0.4267, 0.2752], - [0.5028, 0.3350, 0.3247, 0.4020], - [0.5071, 0.3949, 0.3832, 0.3111], - [0.4666, 0.3583, 0.4521, 0.2777]], - [[0.4686, 0.4700, 0.2433, 0.4896], - [0.4952, 0.3658, 0.3985, 0.4379], - [0.5445, 0.3839, 0.4010, 0.2285], - [0.5038, 0.3151, 0.4543, 0.3893], - [0.4899, 0.4008, 0.4266, 0.3279], - [0.5109, 0.3921, 0.3264, 0.5778], - [0.5315, 0.3788, 0.3103, 0.6088], - [0.4607, 0.4334, 0.4077, 0.4407], - [0.5017, 0.3816, 0.2890, 0.3553], - [0.5284, 0.4586, 0.3681, 0.2760]]]) - >>> vif_score # doctest: +NORMALIZE_WHITESPACE - tensor([[[3.9898e-04, 2.7862e-02, 3.1761e-02, 4.5509e-02], - [1.6094e-04, 1.1518e-02, 2.0446e-02, 8.4023e-02], - [3.7477e-04, 7.8991e-03, 6.6453e-03, 5.7339e-04], - [6.7157e-04, 9.9271e-03, 4.6627e-02, 4.6662e-02], - [9.6011e-04, 1.3214e-02, 2.7918e-02, 1.6376e-02], - [6.7778e-04, 6.1006e-02, 9.8535e-02, 2.5073e-01], - [1.1227e-03, 3.3202e-02, 6.4757e-02, 8.6356e-02], - [1.2290e-04, 1.3186e-02, 3.0758e-02, 1.0355e-01], - [5.8098e-04, 3.3142e-03, 7.3332e-04, 5.8651e-04], - [2.5460e-04, 5.2497e-03, 1.7505e-02, 3.1771e-02]], - [[3.6456e-04, 1.4340e-02, 2.9021e-02, 1.1958e-01], - [1.5903e-04, 3.4139e-02, 1.1511e-01, 1.3284e-01], - [9.7763e-04, 9.1875e-03, 2.0795e-02, 7.2092e-02], - [4.7811e-04, 3.0047e-02, 5.6494e-02, 1.3386e-01], - [1.1665e-03, 1.7940e-02, 5.3484e-02, 1.5105e-01], - [9.6759e-04, 1.7089e-02, 2.1730e-02, 7.3590e-03], - [4.2169e-04, 1.2152e-02, 1.4762e-02, 5.8642e-02], - [1.5370e-04, 1.1013e-02, 1.0387e-02, 1.2726e-02], - [1.0364e-03, 2.8013e-02, 3.8921e-02, 7.5270e-02], - [8.9485e-04, 2.3440e-02, 4.1318e-02, 9.4294e-02]]]) - >>> motion_score - tensor([[ 0.0000, 15.9685, 15.9246, 15.7889, 17.3888, 19.0524, 13.7110, 16.0245, 16.1028, 15.5713], - [14.7679, 15.5407, 15.9964, 17.2818, 18.3270, 19.0149, 16.8640, 16.4841, 16.4464, 17.4890]]) + tensor([[ 7.0141, 17.4276, 15.1429, 14.9831, 19.3378, 12.5638, 13.9680, 13.4165, 17.9314, 16.0849], + [ 7.2760, 11.1951, 15.3990, 13.5877, 15.1370, 18.4508, 17.5596, 18.6859, 12.9309, 15.1975]]) + >>> vmaf = 
VideoMultiMethodAssessmentFusion(features=True) + >>> vmaf_dict = vmaf(preds, target) + >>> vmaf_dict['vmaf'] + tensor([[ 7.0141, 17.4276, 15.1429, 14.9831, 19.3378, 12.5638, 13.9680, 13.4165, 17.9314, 16.0849], + [ 7.2760, 11.1951, 15.3990, 13.5877, 15.1370, 18.4508, 17.5596, 18.6859, 12.9309, 15.1975]]) + >>> vmaf_dict['integer_adm2'] + tensor([[0.3980, 0.4449, 0.3650, 0.4070, 0.4504, 0.4817, 0.4489, 0.3870, 0.4140, 0.4029], + [0.4134, 0.4216, 0.4050, 0.4138, 0.4209, 0.4378, 0.4333, 0.4358, 0.3846, 0.4168]]) + >>> vmaf_dict['integer_vif_scale0'] + tensor([[0.0004, 0.0002, 0.0004, 0.0007, 0.0010, 0.0007, 0.0011, 0.0001, 0.0006, 0.0003], + [0.0004, 0.0002, 0.0010, 0.0005, 0.0012, 0.0010, 0.0004, 0.0002, 0.0010, 0.0009]]) """ @@ -197,7 +164,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: else: self.vmaf_score.append(score) - def compute(self) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor, Tensor]]: + def compute(self) -> Union[Tensor, Dict[str, Tensor]]: """Compute final VMAF score.""" if self.features: return { From c1305c2bbc0fc77ebe4e535e61533250a4790fe4 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 29 Apr 2025 12:04:23 +0200 Subject: [PATCH 23/24] unittests --- tests/unittests/_helpers/testers.py | 6 +- tests/unittests/video/test_vmaf.py | 118 +++++++++++++--------------- 2 files changed, 58 insertions(+), 66 deletions(-) diff --git a/tests/unittests/_helpers/testers.py b/tests/unittests/_helpers/testers.py index cd93adc0b99..624c08ecc63 100644 --- a/tests/unittests/_helpers/testers.py +++ b/tests/unittests/_helpers/testers.py @@ -313,7 +313,11 @@ def _functional_test( **extra_kwargs, ) # assert it is the same - _assert_allclose(tm_result, ref_result, atol=atol) + if isinstance(ref_result, dict): + for key in ref_result: + _assert_allclose(tm_result, ref_result[key].numpy(), atol=atol, key=key) + else: + _assert_allclose(tm_result, ref_result, atol=atol) def _assert_dtype_support( diff --git a/tests/unittests/video/test_vmaf.py b/tests/unittests/video/test_vmaf.py index 03cb4117cf2..f6d6686d598 100644 --- a/tests/unittests/video/test_vmaf.py +++ b/tests/unittests/video/test_vmaf.py @@ -11,13 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial import pytest import torch import vmaf_torch from einops import rearrange -from torchmetrics.functional.video.vmaf import video_multi_method_assessment_fusion +from torchmetrics.functional.video.vmaf import calculate_luma, video_multi_method_assessment_fusion from torchmetrics.utilities.imports import _TORCH_VMAF_AVAILABLE from torchmetrics.video import VideoMultiMethodAssessmentFusion from unittests import _Input @@ -27,107 +28,94 @@ seed_all(42) -def _reference_vmaf(preds, target, elementary_features=False): - """Reference implementation of VMAF metric.""" - device = preds.device - orig_dtype = preds.dtype +def _reference_vmaf(preds, target, features=False): + """Reference implementation of VMAF metric. 
- # Convert to float32 for processing - preds = (preds.clamp(-1, 1).to(torch.float32) + 1) / 2 # [-1, 1] -> [0, 1] - target = (target.clamp(-1, 1).to(torch.float32) + 1) / 2 # [-1, 1] -> [0, 1] - - # Calculate luma component - def calculate_luma(video): - r = video[:, :, 0, :, :] - g = video[:, :, 1, :, :] - b = video[:, :, 2, :, :] - return (0.299 * r + 0.587 * g + 0.114 * b).unsqueeze(1) * 255 # [0, 1] -> [0, 255] + This should preferably be replaced with the python version of the netflix library + https://github.com/Netflix/vmaf + but that requires it to be compiled on the system. + """ + b = preds.shape[0] + orig_dtype, device = preds.dtype, preds.device preds_luma = calculate_luma(preds) target_luma = calculate_luma(target) vmaf = vmaf_torch.VMAF().to(device) - score = vmaf(rearrange(target_luma, "b c t h w -> (b t) c h w"), rearrange(preds_luma, "b c t h w -> (b t) c h w")) - - if elementary_features: - adm = vmaf.compute_adm_features( - rearrange(target_luma, "b c t h w -> (b t) c h w"), rearrange(preds_luma, "b c t h w -> (b t) c h w") - ) - vif = vmaf.compute_vif_features( - rearrange(target_luma, "b c t h w -> (b t) c h w"), rearrange(preds_luma, "b c t h w -> (b t) c h w") + # we need to compute the model for each video separately + if not features: + scores = [ + vmaf.compute_vmaf_score( + rearrange(target_luma[video], "c f h w -> f c h w"), rearrange(preds_luma[video], "c f h w -> f c h w") + ) + for video in range(b) + ] + return torch.cat(scores, dim=1).t().to(orig_dtype) + import pandas as pd # pandas is installed as a dependency of vmaf-torch + + scores_and_features = [ + vmaf.table( + rearrange(target_luma[video], "c f h w -> f c h w"), rearrange(preds_luma[video], "c f h w -> f c h w") ) - motion = vmaf.compute_motion(rearrange(target_luma, "b c t h w -> (b t) c h w")) - return score.squeeze().to(orig_dtype), adm.to(orig_dtype), vif.to(orig_dtype), motion.squeeze().to(orig_dtype) - return score.squeeze().to(orig_dtype) + for video in range(b) + ] + dfs = [scores_and_features[video].apply(pd.to_numeric, errors="coerce") for video in range(b)] + result = [ + {col: torch.tensor(dfs[video][col].values, dtype=orig_dtype) for col in dfs[video].columns if col != "Frame"} + for video in range(b) + ] + return {col: torch.stack([result[video][col] for video in range(b)]) for col in result[0]} # Define inputs -NUM_BATCHES, BATCH_SIZE = 2, 4 +NUM_BATCHES, BATCH_SIZE, FRAMES = 2, 4, 10 _inputs = [] for size in [32, 64]: - preds = torch.rand(NUM_BATCHES, BATCH_SIZE, 3, 10, size, size) - target = torch.rand(NUM_BATCHES, BATCH_SIZE, 3, 10, size, size) + preds = torch.rand(NUM_BATCHES, BATCH_SIZE, 3, FRAMES, size, size) + target = torch.rand(NUM_BATCHES, BATCH_SIZE, 3, FRAMES, size, size) _inputs.append(_Input(preds=preds, target=target)) @pytest.mark.skipif(not _TORCH_VMAF_AVAILABLE, reason="test requires vmaf-torch") -@pytest.mark.parametrize("preds, target", [(i.preds, i.target) for i in _inputs]) +@pytest.mark.parametrize(("preds", "target"), [(i.preds, i.target) for i in _inputs]) +@pytest.mark.parametrize("features", [True, False]) class TestVMAF(MetricTester): """Test class for `VideoMultiMethodAssessmentFusion` metric.""" - atol = 1e-3 + atol = 1e-1 @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) - def test_vmaf(self, preds, target, ddp): + def test_vmaf(self, preds, target, features, ddp): """Test class implementation of metric.""" self.run_class_metric_test( ddp=ddp, preds=preds, target=target, metric_class=VideoMultiMethodAssessmentFusion, - 
reference_metric=_reference_vmaf, + reference_metric=partial(_reference_vmaf, features=features), + metric_args={"features": features}, ) - def test_vmaf_functional(self, preds, target): + def test_vmaf_functional(self, preds, target, features): """Test functional implementation of metric.""" self.run_functional_metric_test( preds=preds, target=target, metric_functional=video_multi_method_assessment_fusion, - reference_metric=_reference_vmaf, - ) - - def test_vmaf_elementary_features(self, preds, target): - """Test that elementary features are returned when requested.""" - # Test functional implementation - score = video_multi_method_assessment_fusion(preds[0], target[0], elementary_features=True) - breakpoint() - assert isinstance(score, tuple) - assert len(score) == 4 # VMAF score + ADM + VIF + motion - assert score[0].shape == (BATCH_SIZE,) # VMAF score shape - assert score[1].shape == (BATCH_SIZE, 4) # ADM shape - assert score[2].shape == (BATCH_SIZE, 4) # VIF shape - assert score[3].shape == (BATCH_SIZE,) # Motion shape - - def test_vmaf_half_cpu(self, preds, target): - """Test for half precision on CPU.""" - self.run_precision_test_cpu( - preds=preds, - target=target, - metric_module=VideoMultiMethodAssessmentFusion, - metric_functional=video_multi_method_assessment_fusion, + reference_metric=partial(_reference_vmaf, features=features), + metric_args={"features": features}, ) - @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") - def test_vmaf_half_gpu(self, preds, target): - """Test for half precision on GPU.""" - self.run_precision_test_gpu( - preds=preds, - target=target, - metric_module=VideoMultiMethodAssessmentFusion, - metric_functional=video_multi_method_assessment_fusion, - ) + def test_vmaf_features_shape(self, preds, target, features): + """Test that the shape of the features is correct.""" + if not features: + return + vmaf_dict = video_multi_method_assessment_fusion(preds[0], target[0], features=features) + for key in vmaf_dict: + assert vmaf_dict[key].shape == (BATCH_SIZE, FRAMES), ( + f"Shape of {key} is incorrect. Expected {(BATCH_SIZE, FRAMES)}, got {vmaf_dict[key].shape}" + ) @pytest.mark.skipif(_TORCH_VMAF_AVAILABLE, reason="test requires vmaf-torch") From 28b4def05524d43cdc6391c681658703b040d23a Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 29 Apr 2025 12:11:44 +0200 Subject: [PATCH 24/24] fix typing --- src/torchmetrics/video/vmaf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/video/vmaf.py b/src/torchmetrics/video/vmaf.py index b820eaf098c..34a85bfef02 100644 --- a/src/torchmetrics/video/vmaf.py +++ b/src/torchmetrics/video/vmaf.py @@ -148,7 +148,7 @@ def __init__(self, features: bool = False, **kwargs: Any) -> None: def update(self, preds: Tensor, target: Tensor) -> None: """Update state with predictions and targets.""" score = video_multi_method_assessment_fusion(preds, target, self.features) - if self.features: + if self.features and isinstance(score, dict): self.vmaf_score.append(score["vmaf"]) self.integer_motion2.append(score["integer_motion2"]) self.integer_motion.append(score["integer_motion"]) @@ -161,7 +161,7 @@ def update(self, preds: Tensor, target: Tensor) -> None: self.integer_vif_scale1.append(score["integer_vif_scale1"]) self.integer_vif_scale2.append(score["integer_vif_scale2"]) self.integer_vif_scale3.append(score["integer_vif_scale3"]) - else: + elif isinstance(score, Tensor): self.vmaf_score.append(score) def compute(self) -> Union[Tensor, Dict[str, Tensor]]:
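
With the series applied, both interfaces return per-frame results, either as a (batch, frames) tensor or, with features=True, as a dictionary of such tensors; a short end-to-end sketch of the final API, where the mean over frames is a hypothetical post-processing step rather than part of the metric:

    import torch
    from torchmetrics.video import VideoMultiMethodAssessmentFusion

    metric = VideoMultiMethodAssessmentFusion(features=True)
    preds = torch.rand(2, 3, 10, 32, 32)    # RGB in [0, 1]
    target = torch.rand(2, 3, 10, 32, 32)
    out = metric(preds, target)             # dict of (batch, frames) tensors, keys as documented above
    clip_scores = out["vmaf"].mean(dim=-1)  # hypothetical per-video pooling, shape (2,)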