diff options
| -rw-r--r-- | rumba/elements/experimentation.py | 30 | 
1 files changed, 15 insertions, 15 deletions
diff --git a/rumba/elements/experimentation.py b/rumba/elements/experimentation.py index 0f734f8..1c27dff 100644 --- a/rumba/elements/experimentation.py +++ b/rumba/elements/experimentation.py @@ -326,7 +326,7 @@ class Experiment(object):              # breadth-first traversal.              enrolled = {first}              frontier = {first} -            dt_connected = set() +            edges_covered = set()              while len(frontier):                  cur = frontier.pop()                  for edge in dif_graphs[dif][cur]: @@ -336,16 +336,15 @@ class Experiment(object):                          assert(enrollee is not None)                          enroller = cur.get_ipcp_by_dif(dif)                          assert(enroller is not None) -                        if self.enrollment_strategy == 'minimal': -                            self.enrollments[-1].append({'dif': dif, -                                                         'enrollee': enrollee, -                                                         'enroller': enroller, -                                                         'lower_dif': edge[1]}) +                        edges_covered.add((enrollee, enroller)) +                        self.enrollments[-1].append({'dif': dif, +                                                     'enrollee': enrollee, +                                                     'enroller': enroller, +                                                     'lower_dif': edge[1]})                          self.mgmt_flows[-1].append({'src': enrollee,                                                      'dst': enroller})                          self.dt_flows[-1].append({'src': enrollee,                                                    'dst': enroller}) -                        dt_connected.add((enrollee, enroller))                          frontier.add(edge[0])              if len(dif.members) != len(enrolled):                  raise Exception("Disconnected DIF found: %s" % (dif,)) @@ -358,16 +357,17 @@ class Experiment(object):                          assert(enrollee is not None)                          enroller = edge[0].get_ipcp_by_dif(dif)                          assert(enroller is not None) -                        if self.enrollment_strategy == 'full-mesh': -                            self.enrollments[-1].append({'dif': dif, -                                                         'enrollee': enrollee, -                                                         'enroller': enroller, -                                                         'lower_dif': edge[1]}) -                        if self.dt_strategy == 'full-mesh': -                            if ((enrollee, enroller) not in dt_connected and -                                    (enroller, enrollee) not in dt_connected): +                        if ((enrollee, enroller) not in edges_covered and +                            (enroller, enrollee) not in edges_covered): +                            if self.enrollment_strategy == 'full-mesh': +                                self.enrollments[-1].append({'dif': dif, +                                                             'enrollee': enrollee, +                                                             'enroller': enroller, +                                                             'lower_dif': edge[1]}) +                            if self.dt_strategy == 'full-mesh':                                  self.dt_flows[-1].append({'src': enrollee,                                                            'dst': enroller}) +                            edges_covered.add((enrollee, enroller))              if not (self.dt_strategy == 'minimal'                      or self.dt_strategy == 'full-mesh') \  | 
