1 /**
2  * @file methods/hmm/hmm_generate_main.cpp
3  * @author Ryan Curtin
4  * @author Michael Fox
5  *
 * Generate a random observation sequence, and the corresponding hidden state
 * sequence, from a given pre-trained HMM.
8  *
9  * mlpack is free software; you may redistribute it and/or modify it under the
10  * terms of the 3-clause BSD license.  You should have received a copy of the
11  * 3-clause BSD license along with mlpack.  If not, see
12  * http://www.opensource.org/licenses/BSD-3-Clause for more information.
13  */
14 #include <mlpack/prereqs.hpp>
15 #include <mlpack/core/util/io.hpp>
16 #include <mlpack/core/util/mlpack_main.hpp>
17 
18 #include "hmm.hpp"
19 #include "hmm_model.hpp"
20 
21 #include <mlpack/methods/gmm/gmm.hpp>
22 #include <mlpack/methods/gmm/diagonal_gmm.hpp>
23 
24 using namespace mlpack;
25 using namespace mlpack::hmm;
26 using namespace mlpack::distribution;
27 using namespace mlpack::util;
28 using namespace mlpack::gmm;
29 using namespace mlpack::math;
30 using namespace arma;
31 using namespace std;
32 
33 // Program Name.
34 BINDING_NAME("Hidden Markov Model (HMM) Sequence Generator");
35 
36 // Short description.
37 BINDING_SHORT_DESC(
38     "A utility to generate random sequences from a pre-trained Hidden Markov "
39     "Model (HMM).  The length of the desired sequence can be specified, and a "
40     "random sequence of observations is returned.");
41 
42 // Long description.
43 BINDING_LONG_DESC(
44     "This utility takes an already-trained HMM, specified as the " +
45     PRINT_PARAM_STRING("model") + " parameter, and generates a random "
46     "observation sequence and hidden state sequence based on its parameters. "
47     "The observation sequence may be saved with the " +
48     PRINT_PARAM_STRING("output") + " output parameter, and the internal state "
49     " sequence may be saved with the " + PRINT_PARAM_STRING("state") + " output"
50     " parameter."
51     "\n\n"
52     "The state to start the sequence in may be specified with the " +
53     PRINT_PARAM_STRING("start_state") + " parameter.");
54 
55 // Example.
56 BINDING_EXAMPLE(
57     "For example, to generate a sequence of length 150 from the HMM " +
58     PRINT_MODEL("hmm") + " and save the observation sequence to " +
59     PRINT_DATASET("observations") + " and the hidden state sequence to " +
60     PRINT_DATASET("states") + ", the following command may be used: "
61     "\n\n" +
62     PRINT_CALL("hmm_generate", "model", "hmm", "length", 150, "output",
63         "observations", "state", "states"));
64 
65 // See also...
66 BINDING_SEE_ALSO("@hmm_train", "#hmm_train");
67 BINDING_SEE_ALSO("@hmm_loglik", "#hmm_loglik");
68 BINDING_SEE_ALSO("@hmm_viterbi", "#hmm_viterbi");
69 BINDING_SEE_ALSO("Hidden Mixture Models on Wikipedia",
70         "https://en.wikipedia.org/wiki/Hidden_Markov_model");
71 BINDING_SEE_ALSO("mlpack::hmm::HMM class documentation",
72         "@doxygen/classmlpack_1_1hmm_1_1HMM.html");
73 
74 PARAM_MODEL_IN_REQ(HMMModel, "model", "Trained HMM to generate sequences with.",
75     "m");
76 PARAM_INT_IN_REQ("length", "Length of sequence to generate.", "l");
77 
78 PARAM_INT_IN("start_state", "Starting state of sequence.", "t", 0);
79 PARAM_MATRIX_OUT("output", "Matrix to save observation sequence to.", "o");
80 PARAM_UMATRIX_OUT("state", "Matrix to save hidden state sequence to.", "S");
81 PARAM_INT_IN("seed", "Random seed.  If 0, 'std::time(NULL)' is used.", "s", 0);
82 
83 // Because we don't know what the type of our HMM is, we need to write a
84 // function which can take arbitrary HMM types.
85 struct Generate
86 {
87   template<typename HMMType>
ApplyGenerate88   static void Apply(HMMType& hmm, void* /* extraInfo */)
89   {
90     mat observations;
91     Row<size_t> sequence;
92 
93     RequireParamValue<int>("start_state", [](int x) { return x >= 0; }, true,
94         "Invalid start state");
95     RequireParamValue<int>("length", [](int x) { return x >= 0; }, true,
96         "Length must be >= 0");
97 
98     // Load the parameters.
99     const size_t startState = (size_t) IO::GetParam<int>("start_state");
100     const size_t length = (size_t) IO::GetParam<int>("length");
101 
102     Log::Info << "Generating sequence of length " << length << "..." << endl;
103     if (startState >= hmm.Transition().n_rows)
104     {
105       Log::Fatal << "Invalid start state (" << startState << "); must be "
106           << "between 0 and number of states (" << hmm.Transition().n_rows
107           << ")!" << endl;
108     }
109 
110     hmm.Generate(length, observations, sequence, startState);
111 
112     // Now save the output.
113     if (IO::HasParam("output"))
114       IO::GetParam<mat>("output") = std::move(observations);
115 
116     // Do we want to save the hidden sequence?
117     if (IO::HasParam("state"))
118       IO::GetParam<Mat<size_t>>("state") = std::move(sequence);
119   }
120 };
121 
mlpackMain()122 static void mlpackMain()
123 {
124   RequireAtLeastOnePassed({ "output", "state" }, false, "no output will be "
125       "saved");
126 
127   // Set random seed.
128   if (IO::GetParam<int>("seed") != 0)
129     RandomSeed((size_t) IO::GetParam<int>("seed"));
130   else
131     RandomSeed((size_t) time(NULL));
132 
133   // Load model, and perform the generation.
134   HMMModel* hmm;
135   hmm = std::move(IO::GetParam<HMMModel*>("model"));
136   hmm->PerformAction<Generate, void>(NULL); // No extra data required.
137 }
138