diff --git a/dpgen/collect/collect.py b/dpgen/collect/collect.py index a8d5b4060..5e48c7987 100644 --- a/dpgen/collect/collect.py +++ b/dpgen/collect/collect.py @@ -24,8 +24,12 @@ def collect_data(target_folder, param_file, output, init_data = [] init_data_prefix = jdata.get('init_data_prefix', '') init_data_sys = jdata.get('init_data_sys', []) - for ii in init_data_sys: - init_data.append(dpdata.LabeledSystem(os.path.join(init_data_prefix, ii), fmt='deepmd/npy')) + for ii in init_data_sys: + if jdata.get('init_multi_systems', False): + for single_sys in os.listdir(os.path.join(init_data_prefix, ii)): + init_data.append(dpdata.LabeledSystem(os.path.join(init_data_prefix, ii, single_sys), fmt='deepmd/npy')) + else: + init_data.append(dpdata.LabeledSystem(os.path.join(init_data_prefix, ii), fmt='deepmd/npy')) # collect systems from iter dirs coll_data = {} numb_sys = len(sys_configs)