diff --git a/tester/api_config/config_analyzer.py b/tester/api_config/config_analyzer.py index 3fd536df..936ec1fc 100644 --- a/tester/api_config/config_analyzer.py +++ b/tester/api_config/config_analyzer.py @@ -1998,7 +1998,7 @@ def get_exponent_max(value, dtype_max, default_max = 5): scalar_val = numpy.random.randint(-65535, 65535) self.numpy_tensor = numpy.array(scalar_val, dtype=self.dtype) else: - scalar_val = numpy.random.random() - 0.5 + scalar_val = (numpy.random.random() - 0.5)*1.2 self.numpy_tensor = numpy.array(scalar_val, dtype=self.dtype) elif USE_CACHED_NUMPY and self.dtype not in ["int64", "float64"]: self.numpy_tensor = self.get_cached_numpy(self.dtype, self.shape) @@ -2006,7 +2006,7 @@ def get_exponent_max(value, dtype_max, default_max = 5): if "int" in self.dtype: self.numpy_tensor = (numpy.random.randint(-65535, 65535, size=self.shape)).astype(self.dtype) else: - self.numpy_tensor = (numpy.random.random(self.shape) - 0.5).astype(self.dtype) + self.numpy_tensor = ((numpy.random.random(self.shape) - 0.5)*1.2).astype(self.dtype) self.dtype = original_dtype return self.numpy_tensor diff --git a/tools/get_cases_from_csv.py b/tools/get_cases_from_csv.py index e58a9a97..0af99ef9 100644 --- a/tools/get_cases_from_csv.py +++ b/tools/get_cases_from_csv.py @@ -4,7 +4,8 @@ app = typer.Typer() -def _get_cases(api_name: str, original_csv: str): + +def _get_cases(api_name: str, only_diff: bool, original_csv: str): with ( open(original_csv, "r") as infile, open(f"filtered_result_{api_name}.csv", "w", newline="") as outfile, @@ -26,7 +27,7 @@ def _get_cases(api_name: str, original_csv: str): last_col = float(row[-1]) if row[-1].strip() else 0 second_last_col = float(row[-2]) if row[-2].strip() else 0 - if last_col < 1e-16 and second_last_col < 1e-16: + if only_diff and last_col < 1e-16 and second_last_col < 1e-16: continue writer.writerow(row) @@ -44,22 +45,19 @@ def _get_cases(api_name: str, original_csv: str): first_col = row[0] last_col = float(row[-1]) if row[-1].strip() else 0 second_last_col = float(row[-2]) if row[-2].strip() else 0 - - if last_col < 1e-16 and second_last_col < 1e-16: - continue outs.append(row[2].replace('""', '"')) - outfile.write("\n".join(outs)) + outfile.write("\n".join(set(outs))) @app.command() def get_cases( api_names: list[str], + only_diff: bool = False, config_path: Path = Path("TotalStableFull.csv"), ): for api_name in api_names: - _get_cases(api_name, config_path.as_posix()) - + _get_cases(api_name, only_diff, config_path.as_posix()) if __name__ == "__main__": - app() \ No newline at end of file + app()