pasp.wlearn
import numpy as np

def learn(P, D: np.ndarray, A: np.ndarray = None, niters: int = 30, alg: str = "em",
          lr: float = 0.001, batch: int = None, smoothing: float = 1e-4, lstable_sat: bool = True,
          display: str = "loglikelihood"):
  if alg == "fixpoint": alg = "em"
  # If batch is not given, set batch to the size of the dataset.
  if batch is None: batch = len(D)
  # Prepare training tensors.
  for N in P.NR: N.prepare_train(batch)
  for N in P.NA: N.prepare_train(batch)
  # Check if data dimensions all match.
  def assert_dims(N, e: int):
    if N.learnable and ((o := len(N.data[0].train)) != e):
      raise ValueError(f"Training data dimensions do not match!\n Expected {e}, got {o}.")
  n = len(D)
  for N in P.NR: assert_dims(N, n)
  for N in P.NA: assert_dims(N, n)

  # Check if D is an np.ndarray or list.
  if not (issubclass(t := type(D), list) or issubclass(t, np.ndarray)):
    raise TypeError(f"Expected dataset of type list or numpy.ndarray, got {t}!")

  # Batch mode.
  if A is None:
    if type(D) is np.ndarray: data = D if np.issubdtype(D.dtype, bytes) else D.astype(bytes)
    else: data = D
    from learn import learn_batch as clearn_batch
    P.train()
    clearn_batch(P, data, niters = niters, alg = alg, lr = lr, batch = batch,
                 lstable_sat = lstable_sat, display = display, smoothing = smoothing)
    P.eval()
    return

  # Non-batch mode.
  if type(A) is not np.ndarray: atoms = np.array(A, dtype = bytes)
  else: atoms = A if np.issubdtype(A.dtype, bytes) else A.astype(bytes)
  if type(D) is not np.ndarray: data = np.array(D, dtype = np.uint8)
  else: data = D if np.issubdtype(D.dtype, np.uint8) else D.astype(np.uint8)

  # Collapse repeated observations and count their multiplicities.
  obs, obs_counts = np.unique(data, axis = 0, return_counts = True)
  from learn import learn as clearn
  P.train()
  clearn(P, obs, obs_counts, atoms, niters = niters, alg = alg, lr = lr, lstable_sat = lstable_sat)
  P.eval()
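In the non-batch branch (when A is given), D is coerced to a uint8 matrix with one row per observation over the atoms listed in A, and repeated observations are collapsed with np.unique so the C backend only sees each distinct observation once, together with its count. A minimal sketch of that preprocessing step, using a made-up three-atom dataset (the atom names and values are placeholders):

import numpy as np

# Hypothetical dataset: each row is one observation of the atoms in A (1 = true, 0 = false).
A = ["a", "b", "c"]
D = np.array([[1, 0, 1],
              [1, 0, 1],
              [0, 1, 1]], dtype = np.uint8)

# The same call learn() makes in non-batch mode: deduplicate rows and count multiplicities.
obs, obs_counts = np.unique(D, axis = 0, return_counts = True)
print(obs)        # [[0 1 1]
                  #  [1 0 1]]
print(obs_counts) # [1 2]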
def learn(P, D: numpy.ndarray, A: numpy.ndarray = None, niters: int = 30, alg: str = 'em',
          lr: float = 0.001, batch: int = None, smoothing: float = 0.0001,
          lstable_sat: bool = True, display: str = 'loglikelihood'):
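A minimal usage sketch of the non-batch path. The program file, the atom names, and pasp.parse as the way to build the program object P are assumptions for illustration, not taken from this module; what is taken from the source above is that A lists the observed atoms, D is coerced to a uint8 matrix with one row per observation, and learn mutates P in place (wrapping the C call between P.train() and P.eval()).

import numpy as np
import pasp                    # assumes the pasp package exposes parse() for building P
from pasp.wlearn import learn  # the function documented on this page

# Hypothetical program with learnable probabilistic facts (placeholder path).
P = pasp.parse("examples/sleep.lp")

# Observed atoms and one 0/1 row per observation (repeated rows are allowed).
A = ["insomnia", "work"]
D = np.array([[1, 1],
              [1, 0],
              [1, 0],
              [0, 0]], dtype = np.uint8)

# 30 EM iterations (the defaults); the parameters of P are updated in place.
learn(P, D, A, niters = 30, alg = "em", lr = 0.001)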