Commit 8dbc07de authored by Jean-Marie Lepioufle's avatar Jean-Marie Lepioufle
Browse files

changes according to export_onnx

parent 00066390
......@@ -13,6 +13,7 @@ def k_fold_train_function(params:dict):
params_data=params['data']
params_model=params['model']
params_train=params['training_param']
params_export_onnx=params['export_onnx']
kfolds = params_train['kfolds']
epochs = params_train['epochs']
batch_size=params_train['batch_size']
......@@ -51,7 +52,7 @@ def k_fold_train_function(params:dict):
res = pd.concat([res,loss_train_tmp,loss_valid_tmp],axis=1)
loss_res = pd.concat([loss_res,res],axis=0)
model_.save(resdir,"fold_"+str(fold))
model_.save_onnx(resdir,"fold_"+str(fold),train_fold)
model_.export_onnx(resdir,"fold_"+str(fold),train_fold,params_export_onnx)
# save the class_ts_x data
f_data = open(data_filename, 'wb')
pickle.dump(data, f_data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment