Skip to content

Commit 0c4bc4f

Browse files
authored
feat: fix travis ci test (kleveross#75)
1 parent 209266e commit 0c4bc4f

File tree

3 files changed

+11
-10
lines changed

3 files changed

+11
-10
lines changed

.isort.cfg

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
[settings]
2-
known_third_party = data,grpc,numpy,setuptools,torch
2+
known_third_party = grpc,numpy,setuptools,torch
33
multi_line_output=3
44
include_trailing_comma=True

test/deprecated-tests/pytorch_test.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from torch.optim.lr_scheduler import StepLR
1414

1515
from ftlib import BasicFTLib
16-
from ftlib.ftlib_status import FTAllReduceStatus
16+
from ftlib.ftlib_status import FTCollectiveStatus
1717

1818
root_dir = os.path.join(os.path.dirname(__file__), os.path.pardir)
1919
sys.path.insert(0, os.path.abspath(root_dir))
@@ -106,15 +106,15 @@ def forward(self, x):
106106
continue
107107
else:
108108
res = ftlib.wait_gradients_ready(model)
109-
if res == FTAllReduceStatus.NO_NEED:
109+
if res == FTCollectiveStatus.NO_NEED:
110110
logging.critical(
111111
"cannot use average_gradient when there is no need"
112112
)
113113
exit(2)
114-
if res == FTAllReduceStatus.SUCCESS:
114+
if res == FTCollectiveStatus.SUCCESS:
115115
logging.info("average succeed")
116116
optimizer.step()
117-
if res == FTAllReduceStatus.ABORT:
117+
if res == FTCollectiveStatus.ABORT:
118118
logging.info("average failed, abort")
119119
continue
120120
scheduler.step()

test/deprecated-tests/tricky-data/pytorch-gossip-tricky-data.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,12 @@
1010
import torch.nn as nn
1111
import torch.nn.functional as F
1212
import torch.optim as optim
13-
from data import TrickySampler
1413
from torch.optim.lr_scheduler import StepLR
1514

1615
from ftlib import BasicFTLib
17-
from ftlib.ftlib_status import FTAllReduceStatus
16+
from ftlib.ftlib_status import FTCollectiveStatus
17+
18+
from .data import TrickySampler
1819

1920
LOGLEVEL = os.environ.get("LOGLEVEL", "WARNING").upper()
2021
logging.basicConfig(level=LOGLEVEL)
@@ -145,15 +146,15 @@ def forward(self, x):
145146
continue
146147
else:
147148
res = ftlib.wait_gradients_ready(model)
148-
if res == FTAllReduceStatus.NO_NEED:
149+
if res == FTCollectiveStatus.NO_NEED:
149150
logging.critical(
150151
"cannot use average_gradient when there is no need"
151152
)
152153
exit(2)
153-
elif res == FTAllReduceStatus.SUCCESS:
154+
elif res == FTCollectiveStatus.SUCCESS:
154155
logging.info("average succeed")
155156
optimizer.step()
156-
elif res == FTAllReduceStatus.ABORT:
157+
elif res == FTCollectiveStatus.ABORT:
157158
logging.info("average failed, abort")
158159
continue
159160
else:

0 commit comments

Comments
 (0)