Skip to content

Commit 3e513f2

Browse files
wangtzcopybara-github
authored andcommitted
Model Maker Audio: avoid sndfile missing issue if resampling is not required.
Currently if the sysmte has no sndfile, an error occurs when you simply import model_maker so it affects all tasks. PiperOrigin-RevId: 373936670
1 parent 414b1ca commit 3e513f2

File tree

2 files changed

+55
-7
lines changed

2 files changed

+55
-7
lines changed

tensorflow_examples/lite/model_maker/core/data_util/audio_dataloader.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,20 @@
2020
import os
2121
import random
2222

23-
import librosa
2423
import pandas as pd
2524
import tensorflow as tf
2625
from tensorflow_examples.lite.model_maker.core.api.api_util import mm_export
2726
from tensorflow_examples.lite.model_maker.core.data_util import dataloader
2827
from tensorflow_examples.lite.model_maker.core.task.model_spec import audio_spec
2928

29+
error_import_librosa = None
30+
try:
31+
import librosa # pylint: disable=g-import-not-at-top
32+
ENABLE_RESAMPLE = True
33+
except (OSError, ImportError) as _error_import_librosa: # pylint: disable=invalid-name
34+
ENABLE_RESAMPLE = False
35+
error_import_librosa = _error_import_librosa
36+
3037

3138
class ExamplesHelper(object):
3239
"""Helper class for matching examples and labels."""
@@ -306,8 +313,15 @@ def _load_wav(filepath, label):
306313
# This is a eager mode numpy_function. It can be converted to a tf.function
307314
# using https://www.tensorflow.org/io/api_docs/python/tfio/audio/resample
308315
def _resample_numpy(waveform, sample_rate, label):
309-
waveform = librosa.resample(
310-
waveform, orig_sr=sample_rate, target_sr=spec.target_sample_rate)
316+
if ENABLE_RESAMPLE:
317+
waveform = librosa.resample(
318+
waveform, orig_sr=sample_rate, target_sr=spec.target_sample_rate)
319+
else:
320+
error_message = (
321+
'Failed to import librosa. You might be missing sndfile, which '
322+
'can be installed via `sudo apt-get install libsndfile1` on '
323+
'Ubuntu/Debian.')
324+
raise RuntimeError(error_message) from error_import_librosa
311325
return waveform, label
312326

313327
@tf.function

tensorflow_examples/lite/model_maker/core/data_util/audio_dataloader_test.py

+38-4
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818

1919
import csv
2020
import os
21+
import shutil
2122
import unittest
2223

2324
import numpy as np
2425
from scipy.io import wavfile
25-
2626
import tensorflow.compat.v2 as tf
2727
from tensorflow_examples.lite.model_maker.core.data_util import audio_dataloader
2828
from tensorflow_examples.lite.model_maker.core.task.model_spec import audio_spec
@@ -106,7 +106,7 @@ class Base(tf.test.TestCase):
106106
def _get_folder_path(self, sub_folder_name):
107107
folder_path = os.path.join(self.get_temp_dir(), sub_folder_name)
108108
if os.path.exists(folder_path):
109-
return
109+
shutil.rmtree(folder_path)
110110
tf.compat.v1.logging.info('Test path: %s', folder_path)
111111
os.mkdir(folder_path)
112112
return folder_path
@@ -117,7 +117,7 @@ def _get_folder_path(self, sub_folder_name):
117117
class LoadFromESC50Test(Base):
118118

119119
def test_from_esc50(self):
120-
folder_path = self._get_folder_path('test_examples_helper')
120+
folder_path = self._get_folder_path('test_from_esc50')
121121

122122
headers = [
123123
'filename', 'fold', 'target', 'category', 'esc10', 'src_file', 'take'
@@ -215,7 +215,7 @@ def fullpath(name):
215215
class LoadFromFolderTest(Base):
216216

217217
def test_spec(self):
218-
folder_path = self._get_folder_path('test_examples_helper')
218+
folder_path = self._get_folder_path('test_spec')
219219
write_sample(folder_path, 'unknown', '2s.wav', 44100, 2, value=1)
220220

221221
spec = audio_spec.YAMNetSpec()
@@ -231,6 +231,40 @@ def test_no_audio_files_found(self):
231231
spec = MockSpec(model_dir=folder_path)
232232
audio_dataloader.DataLoader.from_folder(spec, folder_path)
233233

234+
def test_failed_librosa_imoprt(self):
235+
# Temporarily disable resampling.
236+
audio_dataloader.ENABLE_RESAMPLE = False
237+
238+
# Pretend a real import failure.
239+
try:
240+
import inexistent_package # pylint: disable=g-import-not-at-top,unused-import
241+
except (OSError, ImportError) as e:
242+
audio_dataloader.error_import_librosa = e
243+
244+
try:
245+
folder_path = self._get_folder_path('test_failed_librosa_imoprt')
246+
247+
# No error occured if resampling is not needed.
248+
write_sample(folder_path, 'background', '1s.wav', 44100, 1, value=0)
249+
spec = MockSpec(model_dir=folder_path)
250+
loader = audio_dataloader.DataLoader.from_folder(spec, folder_path)
251+
self.assertEqual(len(loader), 1)
252+
self.assertEqual(len(list(loader.gen_dataset())), 1)
253+
254+
# Error occured when resampling is needed.
255+
write_sample(folder_path, 'command0', '1.8s.wav', 4410, 1.8, value=5)
256+
spec = MockSpec(model_dir=folder_path)
257+
loader = audio_dataloader.DataLoader.from_folder(spec, folder_path)
258+
self.assertEqual(len(loader), 2)
259+
with self.assertRaisesRegexp(tf.errors.UnknownError,
260+
'sudo apt-get install libsndfile1'):
261+
_ = list(loader.gen_dataset())
262+
263+
finally:
264+
# Set it back
265+
audio_dataloader.ENABLE_RESAMPLE = True
266+
audio_dataloader.error_import_librosa = None
267+
234268
def test_from_folder(self):
235269
folder_path = self._get_folder_path('test_from_folder')
236270
write_sample(folder_path, 'background', '2s.wav', 44100, 2, value=0)

0 commit comments

Comments
 (0)