# coding=utf-8
# Taken from https://github.com/google/compare_gan/blob/master/compare_gan/src/prd_score_test.py
# Copyright 2018 Google LLC & Hwalsuk Lee.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Testing precision and recall computation on synthetic data."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest

from compare_gan.src import prd_score as prd

import numpy as np


class PRDTest(unittest.TestCase):

  def test_compute_prd_no_overlap(self):
    # Disjoint supports: the whole PRD curve collapses to (0, 0).
    eval_dist = [0, 1]
    ref_dist = [1, 0]
    result = np.ravel(prd.compute_prd(eval_dist, ref_dist))
    np.testing.assert_almost_equal(result, 0)

  def test_compute_prd_perfect_overlap(self):
    # Identical distributions: the curve passes through (1, 1) at the
    # middle angle.
    eval_dist = [1, 0]
    ref_dist = [1, 0]
    result = prd.compute_prd(eval_dist, ref_dist, num_angles=11)
    np.testing.assert_almost_equal([result[0][5], result[1][5]], [1, 1])

  def test_compute_prd_low_precision_high_recall(self):
    eval_dist = [0.5, 0.5]
    ref_dist = [1, 0]
    result = prd.compute_prd(eval_dist, ref_dist, num_angles=11)
    np.testing.assert_almost_equal(result[0][5], 0.5)
    np.testing.assert_almost_equal(result[1][5], 0.5)
    np.testing.assert_almost_equal(result[0][10], 0.5)
    np.testing.assert_almost_equal(result[1][1], 1)

  def test_compute_prd_high_precision_low_recall(self):
    eval_dist = [1, 0]
    ref_dist = [0.5, 0.5]
    result = prd.compute_prd(eval_dist, ref_dist, num_angles=11)
    np.testing.assert_almost_equal([result[0][5], result[1][5]], [0.5, 0.5])
    np.testing.assert_almost_equal(result[1][1], 0.5)
    np.testing.assert_almost_equal(result[0][10], 1)

  def test_compute_prd_bad_epsilon(self):
    # epsilon must lie strictly between 0 and 1.
    with self.assertRaises(ValueError):
      prd.compute_prd([1], [1], epsilon=0)
    with self.assertRaises(ValueError):
      prd.compute_prd([1], [1], epsilon=1)
    with self.assertRaises(ValueError):
      prd.compute_prd([1], [1], epsilon=-1)

  def test_compute_prd_bad_num_angles(self):
    # num_angles must be an integer in [2, 1e6].
    with self.assertRaises(ValueError):
      prd.compute_prd([1], [1], num_angles=0)
    with self.assertRaises(ValueError):
      prd.compute_prd([1], [1], num_angles=1)
    with self.assertRaises(ValueError):
      prd.compute_prd([1], [1], num_angles=-1)
    with self.assertRaises(ValueError):
      prd.compute_prd([1], [1], num_angles=1e6 + 1)
    with self.assertRaises(ValueError):
      prd.compute_prd([1], [1], num_angles=2.5)

  def test__cluster_into_bins(self):
    # Each of the two returned histograms has one entry per cluster and
    # sums to 1.
    eval_data = np.zeros([5, 4])
    ref_data = np.ones([5, 4])
    result = prd._cluster_into_bins(eval_data, ref_data, 3)
    self.assertEqual(len(result), 2)
    self.assertEqual(len(result[0]), 3)
    self.assertEqual(len(result[1]), 3)
    np.testing.assert_almost_equal(sum(result[0]), 1)
    np.testing.assert_almost_equal(sum(result[1]), 1)

  def test_compute_prd_from_embedding_mismatch_num_samples_should_fail(self):
    # Mismatch in number of samples with enforce_balance set to True.
    with self.assertRaises(ValueError):
      prd.compute_prd_from_embedding(
          np.array([[0], [0], [1]]), np.array([[0], [1]]),
          num_clusters=2, enforce_balance=True)

  def test_compute_prd_from_embedding_mismatch_num_samples_should_work(self):
    # Mismatch in number of samples with enforce_balance set to False.
    try:
      prd.compute_prd_from_embedding(
          np.array([[0], [0], [1]]), np.array([[0], [1]]),
          num_clusters=2, enforce_balance=False)
    except ValueError:
      self.fail(
          'compute_prd_from_embedding should not raise a ValueError when '
          'enforce_balance is set to False.')

  def test__prd_to_f_beta_correct_computation(self):
    precision = np.array([1, 1, 0, 0, 0.5, 1, 0.5])
    recall = np.array([1, 0, 1, 0, 0.5, 0.5, 1])
    expected = np.array([1, 0, 0, 0, 0.5, 2/3, 2/3])
    with np.errstate(invalid='ignore'):
      result = prd._prd_to_f_beta(precision, recall, beta=1)
    np.testing.assert_almost_equal(result, expected)

    expected = np.array([1, 0, 0, 0, 0.5, 5/9, 5/6])
    with np.errstate(invalid='ignore'):
      result = prd._prd_to_f_beta(precision, recall, beta=2)
    np.testing.assert_almost_equal(result, expected)

    expected = np.array([1, 0, 0, 0, 0.5, 5/6, 5/9])
    with np.errstate(invalid='ignore'):
      result = prd._prd_to_f_beta(precision, recall, beta=1/2)
    np.testing.assert_almost_equal(result, expected)

    result = prd._prd_to_f_beta(np.array([]), np.array([]), beta=1)
    expected = np.array([])
    np.testing.assert_almost_equal(result, expected)

  def test__prd_to_f_beta_bad_beta(self):
    # beta must be positive.
    with self.assertRaises(ValueError):
      prd._prd_to_f_beta(np.ones(1), np.ones(1), beta=0)
    with self.assertRaises(ValueError):
      prd._prd_to_f_beta(np.ones(1), np.ones(1), beta=-3)

  def test__prd_to_f_beta_bad_precision_or_recall(self):
    # Negative precision or recall values are rejected.
    with self.assertRaises(ValueError):
      prd._prd_to_f_beta(-np.ones(1), np.ones(1), beta=1)
    with self.assertRaises(ValueError):
      prd._prd_to_f_beta(np.ones(1), -np.ones(1), beta=1)

  def test_plot_not_enough_labels(self):
    with self.assertRaises(ValueError):
      prd.plot(np.zeros([3, 2, 5]), labels=['1', '2'])

  def test_plot_too_many_labels(self):
    with self.assertRaises(ValueError):
      prd.plot(np.zeros([1, 2, 5]), labels=['1', '2', '3'])


if __name__ == '__main__':
  unittest.main()
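

# ----------------------------------------------------------------------------
# Illustrative usage sketch, not part of the original test suite. It assumes,
# consistent with the `result[0]`/`result[1]` indexing in the tests above,
# that compute_prd_from_embedding returns a (precision, recall) pair of
# arrays; the synthetic Gaussian embeddings and num_clusters=20 are arbitrary
# example values. Defined below the `unittest.main()` guard, it is never run
# by the test runner and is only meant to be imported and called manually.
def _example_prd_usage():
  rng = np.random.RandomState(0)
  # Two 16-dimensional Gaussian clouds with slightly different means, standing
  # in for "real" and "generated" feature embeddings.
  ref_embeddings = rng.normal(loc=0.0, scale=1.0, size=(1000, 16))
  eval_embeddings = rng.normal(loc=0.5, scale=1.0, size=(1000, 16))
  precision, recall = prd.compute_prd_from_embedding(
      eval_embeddings, ref_embeddings, num_clusters=20)
  # Summarize the PRD curve by maximum F_beta scores (e.g. an (F_8, F_1/8)
  # pair): beta > 1 weighs recall more heavily, beta < 1 weighs precision.
  with np.errstate(invalid='ignore'):
    f_8 = np.max(prd._prd_to_f_beta(precision, recall, beta=8))
    f_1_8 = np.max(prd._prd_to_f_beta(precision, recall, beta=1/8))
  print('max F_8 = %.3f, max F_1/8 = %.3f' % (f_8, f_1_8))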