Commit 461d5caf authored by Jean-Marie Lepioufle's avatar Jean-Marie Lepioufle
Browse files

add param to export_onnx

parent 957a8477
......@@ -50,11 +50,17 @@ class class_model():
with open(params_path, "w+") as p:
json.dump(self.dump, p)
def save_onnx(self, path: str, name:str, class_ts_) -> None:
if not os.path.exists(path):
os.mkdir(path)
model_path = os.path.join(path, name + "_model.onnx")
if class_ts_.nb_input() == 1:
torch.onnx.export(self.model, class_ts_[0][0], model_path)
elif class_ts_.nb_input() > 1:
torch.onnx.export(self.model, tuple(class_ts_[0][0]), model_path)
def export_onnx(self, path: str, name:str, class_ts_,params:dict):
if params is not None:
if params['param']['opset_version'] is not None:
opset_version = params['param']['opset_version']
else:
opset_version = 9
# further params ...
if not os.path.exists(path):
os.mkdir(path)
model_path = os.path.join(path, name + "_model.onnx")
if class_ts_.nb_input() == 1:
torch.onnx.export(self.model, class_ts_[0][0], model_path)
elif class_ts_.nb_input() > 1:
torch.onnx.export(self.model, class_ts_.get_rand_input(), model_path,opset_version=opset_version)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment