@@ -95,9 +95,6 @@ def __init__(
9595 dtype = None ,
9696 name = None ,
9797 ):
98- if dtype is None :
99- dtype = "int64" if output_mode == "int" else backend .floatx ()
100-
10198 super ().__init__ (name = name , dtype = dtype )
10299
103100 if sparse and not backend .SUPPORTS_SPARSE_TENSORS :
@@ -155,6 +152,10 @@ def __init__(
155152 def input_dtype (self ):
156153 return backend .floatx ()
157154
155+ @property
156+ def output_dtype (self ):
157+ return self .compute_dtype if self .output_mode != "int" else "int32"
158+
158159 def adapt (self , data , steps = None ):
159160 """Computes bin boundaries from quantiles in a input dataset.
160161
@@ -213,7 +214,7 @@ def reset_state(self):
213214 self .summary = np .array ([[], []], dtype = "float32" )
214215
215216 def compute_output_spec (self , inputs ):
216- return backend .KerasTensor (shape = inputs .shape , dtype = self .compute_dtype )
217+ return backend .KerasTensor (shape = inputs .shape , dtype = self .output_dtype )
217218
218219 def load_own_variables (self , store ):
219220 if len (store ) == 1 :
@@ -234,7 +235,7 @@ def call(self, inputs):
234235 indices ,
235236 output_mode = self .output_mode ,
236237 depth = len (self .bin_boundaries ) + 1 ,
237- dtype = self .compute_dtype ,
238+ dtype = self .output_dtype ,
238239 sparse = self .sparse ,
239240 backend_module = self .backend ,
240241 )
0 commit comments