diff --git a/biomes/model.py b/biomes/model.py index 8ef0ae7..12dd1b3 100644 --- a/biomes/model.py +++ b/biomes/model.py @@ -121,7 +121,7 @@ class Model(): # map_callback = MapHistory() extra_params = {} - if self.class_weight: + if self.class_weight.any(): extra_params['class_weight'] = self.class_weight out = self.model.fit(