1010
1111use crate :: tensor:: OrtexTensor ;
1212use crate :: utils:: map_opt_level;
13- use std:: convert:: { Into , TryFrom } ;
13+ use std:: convert:: { TryFrom , TryInto } ;
1414
15- use ort:: { Error , ExecutionProviderDispatch , Session , Value } ;
15+ use ort:: { Error , ExecutionProviderDispatch , Session } ;
1616use rustler:: resource:: ResourceArc ;
1717use rustler:: Atom ;
1818
@@ -37,15 +37,11 @@ pub fn init(
3737) -> Result < OrtexModel , Error > {
3838 // TODO: send tracing logs to erlang/elixir _somehow_
3939 // tracing_subscriber::fmt::init();
40- ort:: init ( )
41- . with_execution_providers ( & eps)
42- . with_name ( "ortex-model" )
43- . commit ( ) ?;
4440
4541 let session = Session :: builder ( ) ?
46- . with_execution_providers ( & eps) ?
4742 . with_optimization_level ( map_opt_level ( opt) ) ?
48- . with_model_from_file ( model_path) ?;
43+ . with_execution_providers ( eps) ?
44+ . commit_from_file ( model_path) ?;
4945
5046 let state = OrtexModel { session } ;
5147 Ok ( state)
@@ -88,21 +84,22 @@ pub fn run(
8884 inputs : Vec < ResourceArc < OrtexTensor > > ,
8985) -> Result < Vec < ( ResourceArc < OrtexTensor > , Vec < usize > , Atom , usize ) > , Error > {
9086 // TODO: can we handle an error more elegantly than just .unwrap()?
91- let final_input: Vec < Value > = inputs
92- . into_iter ( )
93- . map ( |x| Value :: try_from ( & * x) . unwrap ( ) )
94- . collect ( ) ;
87+
88+ let mut ortified_inputs: Vec < ort:: SessionInputValue > = Vec :: new ( ) ;
89+ for input in inputs {
90+ let derefed_input: & OrtexTensor = & input;
91+ let v: ort:: SessionInputValue = derefed_input. try_into ( ) ?;
92+ ortified_inputs. push ( v) ;
93+ }
9594
9695 // Grab the session and run a forward pass with it
97- let session = & model. session ;
96+ let session: & ort :: Session = & model. session ;
9897
9998 // Construct a Vec of ModelOutput enums based on the DynOrtTensor data type
100- let outputs = session. run ( & final_input[ ..] ) ?;
101-
99+ let outputs = session. run ( & ortified_inputs[ ..] ) ?;
102100 outputs
103101 . iter ( )
104102 . map ( |( _name, val) | {
105- let val: & Value = val;
106103 let ortextensor: OrtexTensor = OrtexTensor :: try_from ( val) ?;
107104 let shape = ortextensor. shape ( ) ;
108105 let ( dtype, bits) = ortextensor. dtype ( ) ;
0 commit comments