Skip to content

Commit 175da8d

Browse files
authored
Fix custom init sorting script (huggingface#16864)
1 parent 67ed0e4 commit 175da8d

File tree

3 files changed

+21
-11
lines changed

3 files changed

+21
-11
lines changed

src/transformers/__init__.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -446,27 +446,24 @@
446446
# tokenizers-backed objects
447447
if is_tokenizers_available():
448448
# Fast tokenizers
449-
_import_structure["models.realm"].append("RealmTokenizerFast")
450-
_import_structure["models.xglm"].append("XGLMTokenizerFast")
451-
_import_structure["models.fnet"].append("FNetTokenizerFast")
452-
_import_structure["models.roformer"].append("RoFormerTokenizerFast")
453-
_import_structure["models.clip"].append("CLIPTokenizerFast")
454-
_import_structure["models.convbert"].append("ConvBertTokenizerFast")
455-
_import_structure["models.blenderbot_small"].append("BlenderbotSmallTokenizerFast")
456449
_import_structure["models.albert"].append("AlbertTokenizerFast")
457450
_import_structure["models.bart"].append("BartTokenizerFast")
458451
_import_structure["models.barthez"].append("BarthezTokenizerFast")
459452
_import_structure["models.bert"].append("BertTokenizerFast")
460453
_import_structure["models.big_bird"].append("BigBirdTokenizerFast")
461454
_import_structure["models.blenderbot"].append("BlenderbotTokenizerFast")
455+
_import_structure["models.blenderbot_small"].append("BlenderbotSmallTokenizerFast")
462456
_import_structure["models.camembert"].append("CamembertTokenizerFast")
457+
_import_structure["models.clip"].append("CLIPTokenizerFast")
458+
_import_structure["models.convbert"].append("ConvBertTokenizerFast")
463459
_import_structure["models.deberta"].append("DebertaTokenizerFast")
464460
_import_structure["models.deberta_v2"].append("DebertaV2TokenizerFast")
465461
_import_structure["models.distilbert"].append("DistilBertTokenizerFast")
466462
_import_structure["models.dpr"].extend(
467463
["DPRContextEncoderTokenizerFast", "DPRQuestionEncoderTokenizerFast", "DPRReaderTokenizerFast"]
468464
)
469465
_import_structure["models.electra"].append("ElectraTokenizerFast")
466+
_import_structure["models.fnet"].append("FNetTokenizerFast")
470467
_import_structure["models.funnel"].append("FunnelTokenizerFast")
471468
_import_structure["models.gpt2"].append("GPT2TokenizerFast")
472469
_import_structure["models.herbert"].append("HerbertTokenizerFast")
@@ -483,13 +480,16 @@
483480
_import_structure["models.mt5"].append("MT5TokenizerFast")
484481
_import_structure["models.openai"].append("OpenAIGPTTokenizerFast")
485482
_import_structure["models.pegasus"].append("PegasusTokenizerFast")
483+
_import_structure["models.realm"].append("RealmTokenizerFast")
486484
_import_structure["models.reformer"].append("ReformerTokenizerFast")
487485
_import_structure["models.rembert"].append("RemBertTokenizerFast")
488486
_import_structure["models.retribert"].append("RetriBertTokenizerFast")
489487
_import_structure["models.roberta"].append("RobertaTokenizerFast")
488+
_import_structure["models.roformer"].append("RoFormerTokenizerFast")
490489
_import_structure["models.splinter"].append("SplinterTokenizerFast")
491490
_import_structure["models.squeezebert"].append("SqueezeBertTokenizerFast")
492491
_import_structure["models.t5"].append("T5TokenizerFast")
492+
_import_structure["models.xglm"].append("XGLMTokenizerFast")
493493
_import_structure["models.xlm_roberta"].append("XLMRobertaTokenizerFast")
494494
_import_structure["models.xlnet"].append("XLNetTokenizerFast")
495495
_import_structure["tokenization_utils_fast"] = ["PreTrainedTokenizerFast"]

src/transformers/models/marian/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848

4949
if is_flax_available():
5050
_import_structure["modeling_flax_marian"] = ["FlaxMarianModel", "FlaxMarianMTModel", "FlaxMarianPreTrainedModel"]
51+
5152
if TYPE_CHECKING:
5253
from .configuration_marian import MARIAN_PRETRAINED_CONFIG_ARCHIVE_MAP, MarianConfig, MarianOnnxConfig
5354

utils/custom_init_isort.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -183,11 +183,20 @@ def sort_imports(file, check_only=True):
183183
# Check if the block contains some `_import_structure`s thingy to sort.
184184
block = main_blocks[block_idx]
185185
block_lines = block.split("\n")
186-
if len(block_lines) < 3 or "_import_structure" not in "".join(block_lines[:2]):
186+
187+
# Get to the start of the imports.
188+
line_idx = 0
189+
while line_idx < len(block_lines) and "_import_structure" not in block_lines[line_idx]:
190+
# Skip dummy import blocks
191+
if "import dummy" in block_lines[line_idx]:
192+
line_idx = len(block_lines)
193+
else:
194+
line_idx += 1
195+
if line_idx >= len(block_lines):
187196
continue
188197

189-
# Ignore first and last line: they don't contain anything.
190-
internal_block_code = "\n".join(block_lines[1:-1])
198+
# Ignore beginning and last line: they don't contain anything.
199+
internal_block_code = "\n".join(block_lines[line_idx:-1])
191200
indent = get_indent(block_lines[1])
192201
# Slit the internal block into blocks of indent level 1.
193202
internal_blocks = split_code_in_indented_blocks(internal_block_code, indent_level=indent)
@@ -211,7 +220,7 @@ def sort_imports(file, check_only=True):
211220
count += 1
212221

213222
# And we put our main block back together with its first and last line.
214-
main_blocks[block_idx] = "\n".join([block_lines[0]] + reorderded_blocks + [block_lines[-1]])
223+
main_blocks[block_idx] = "\n".join(block_lines[:line_idx] + reorderded_blocks + [block_lines[-1]])
215224

216225
if code != "\n".join(main_blocks):
217226
if check_only:

0 commit comments

Comments (0)