Skip to content

Commit a3a420a

Browse files
authored
Merge pull request #46 from solegalli/recreate_load_boston_function
[MRG] replace boston import from sklearn
2 parents 03ec379 + 45791c2 commit a3a420a

File tree

3 files changed

+543
-1
lines changed

3 files changed

+543
-1
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
*.DS_Store
12
*.pyc
23
*.joblib
34
*egg-info

tests/conftest.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
# -*- coding: utf-8 -*-
2+
from os.path import dirname
3+
from os.path import join
4+
import csv
25
import pytest
36
import numpy as np
4-
from sklearn.datasets import fetch_20newsgroups, load_boston, load_iris
7+
from sklearn.datasets import fetch_20newsgroups, load_iris
58
from sklearn.utils import shuffle
69

710
NEWSGROUPS_CATEGORIES = [
@@ -48,6 +51,36 @@ def newsgroups_train_big():
4851
def newsgroups_train_binary_big():
4952
return _get_newsgroups(binary=True, remove_chrome=True, size=1000)
5053

54+
class Bunch(dict):
55+
"""Container object for datasets: dictionary-like object that
56+
exposes its keys as attributes."""
57+
58+
def __init__(self, **kwargs):
59+
dict.__init__(self, kwargs)
60+
self.__dict__ = self
61+
62+
def load_boston():
63+
module_path = dirname(__file__)
64+
65+
data_file_name = join(module_path, 'data', 'boston_house_prices.csv')
66+
with open(data_file_name) as f:
67+
data_file = csv.reader(f)
68+
temp = next(data_file)
69+
n_samples = int(temp[0])
70+
n_features = int(temp[1])
71+
data = np.empty((n_samples, n_features))
72+
target = np.empty((n_samples,))
73+
temp = next(data_file) # names of features
74+
feature_names = np.array(temp)
75+
76+
for i, d in enumerate(data_file):
77+
data[i] = np.asarray(d[:-1], dtype=float)
78+
target[i] = np.asarray(d[-1], dtype=float)
79+
80+
return Bunch(data=data,
81+
target=target,
82+
# last column is target value
83+
feature_names=feature_names[:-1])
5184

5285
@pytest.fixture(scope="session")
5386
def boston_train(size=SIZE):

0 commit comments

Comments
 (0)