diff --git a/api/auth_api.py b/api/auth_api.py index a3dff3d..2247d64 100644 --- a/api/auth_api.py +++ b/api/auth_api.py @@ -13,12 +13,13 @@ from functools import wraps from random import randint from flask_login import logout_user, login_user +from typing import Iterable from werkzeug.routing import BuildError from backend import db, app from backend.api import auth_api_bp from backend.auth import AUTH_PROVIDERS, oidc_auth -from backend.models.user_model import User +from backend.models.user_model import User, Group def create_jwt(user: User, validity_min=30): @@ -68,20 +69,40 @@ def login(): return jsonify({'token': token.decode('UTF-8')}) +def check_and_create_groups(groups: Iterable[str]): + user_groups = [] + for g in groups: + group = Group.get_by_name(g) + if group is None: + group = Group(name=g) + db.session.add(group) + user_groups.append(group) + + db.session.commit() + return user_groups + def create_or_retrieve_user_from_userinfo(userinfo): try: email = userinfo["email"] except KeyError: return None + + user_groups = check_and_create_groups(groups=userinfo.get("memberOf", [])) user = User.get_by_identifier(email) if user is not None: - app.logger.info("user found") + app.logger.info("user found -> update user") + user.first_name = userinfo.get("given_name", "") + user.last_name = userinfo.get("family_name", "") + for g in user_groups: + user.groups.append(g) + db.session.commit() return user user = User(email=email, first_name=userinfo.get("given_name", ""), - last_name=userinfo.get("family_name", "")) + last_name=userinfo.get("family_name", ""), external_user=True, + groups=userinfo.get("memberOf", [])) app.logger.info("creating new user") @@ -93,7 +114,10 @@ def create_or_retrieve_user_from_userinfo(userinfo): @auth_api_bp.route('/oidc', methods=['GET']) @oidc_auth.oidc_auth() def oidc(): + user = create_or_retrieve_user_from_userinfo(flask.session['userinfo']) + + return jsonify(user.to_dict()) if user is None: return "Could not authenticate: could not find or create user.", 401 if current_app.config.get("AUTH_RETURN_EXTERNAL_JWT", False): diff --git a/app.db b/app.db index 42ae2c9..c8eb8e4 100644 Binary files a/app.db and b/app.db differ diff --git a/auth/config.py b/auth/config.py index e6c099d..75e647f 100644 --- a/auth/config.py +++ b/auth/config.py @@ -14,12 +14,12 @@ AUTH_PROVIDERS: Dict[str, Dict[str, str]] = { "KIT OIDC (API)": { "type": "api_oidc", - "url": "auth_api_bp.oidc" + "url": "auth_api.oidc" }, "User-Password (API)": { "type": "api_login_form", - "url": "auth_api_bp.base_login" + "url": "auth_api.login" }, } diff --git a/database/database.py b/database/database.py index 77453a6..bd89b86 100644 --- a/database/database.py +++ b/database/database.py @@ -15,6 +15,6 @@ def init_db(): # import all modules here that might define models so that # they will be registered properly on the metadata. Otherwise # you will have to import them first before calling init_db() - import app.models.user - import app.models.lock + import backend.app.models.user + import backend.app.models.lock metadata.create_all(bind=engine) diff --git a/migrations/README b/migrations/README new file mode 100644 index 0000000..98e4f9c --- /dev/null +++ b/migrations/README @@ -0,0 +1 @@ +Generic single-database configuration. \ No newline at end of file diff --git a/migrations/alembic.ini b/migrations/alembic.ini new file mode 100644 index 0000000..f8ed480 --- /dev/null +++ b/migrations/alembic.ini @@ -0,0 +1,45 @@ +# A generic, single database configuration. + +[alembic] +# template used to generate migration files +# file_template = %%(rev)s_%%(slug)s + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/migrations/env.py b/migrations/env.py new file mode 100644 index 0000000..169d487 --- /dev/null +++ b/migrations/env.py @@ -0,0 +1,95 @@ +from __future__ import with_statement + +import logging +from logging.config import fileConfig + +from sqlalchemy import engine_from_config +from sqlalchemy import pool + +from alembic import context + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +fileConfig(config.config_file_name) +logger = logging.getLogger('alembic.env') + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +from flask import current_app +config.set_main_option('sqlalchemy.url', + current_app.config.get('SQLALCHEMY_DATABASE_URI')) +target_metadata = current_app.extensions['migrate'].db.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline(): + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, target_metadata=target_metadata, literal_binds=True + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online(): + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + + # this callback is used to prevent an auto-migration from being generated + # when there are no changes to the schema + # reference: http://alembic.zzzcomputing.com/en/latest/cookbook.html + def process_revision_directives(context, revision, directives): + if getattr(config.cmd_opts, 'autogenerate', False): + script = directives[0] + if script.upgrade_ops.is_empty(): + directives[:] = [] + logger.info('No changes in schema detected.') + + connectable = engine_from_config( + config.get_section(config.config_ini_section), + prefix='sqlalchemy.', + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, + target_metadata=target_metadata, + process_revision_directives=process_revision_directives, + **current_app.extensions['migrate'].configure_args + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/migrations/script.py.mako b/migrations/script.py.mako new file mode 100644 index 0000000..2c01563 --- /dev/null +++ b/migrations/script.py.mako @@ -0,0 +1,24 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + + +def upgrade(): + ${upgrades if upgrades else "pass"} + + +def downgrade(): + ${downgrades if downgrades else "pass"} diff --git a/migrations/versions/6f980d1e7ac5_.py b/migrations/versions/6f980d1e7ac5_.py new file mode 100644 index 0000000..1d5ffa6 --- /dev/null +++ b/migrations/versions/6f980d1e7ac5_.py @@ -0,0 +1,35 @@ +"""empty message + +Revision ID: 6f980d1e7ac5 +Revises: +Create Date: 2019-04-02 13:33:29.319719 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '6f980d1e7ac5' +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('groups', + sa.Column('id', sa.Integer(), autoincrement=True, nullable=False), + sa.Column('name', sa.Unicode(length=64), nullable=False), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('name') + ) + op.add_column('user', sa.Column('external_user', sa.Boolean(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('user', 'external_user') + op.drop_table('groups') + # ### end Alembic commands ### diff --git a/models/user_model.py b/models/user_model.py index 6c7e489..077b0bd 100644 --- a/models/user_model.py +++ b/models/user_model.py @@ -2,8 +2,10 @@ """ Example user model and related models """ +from sqlalchemy.orm import relation +from sqlalchemy import MetaData -from backend import db, app +from backend import db, app, login_manager from backend.models.post_model import Post from backend.models.example_model import ExampleDataItem import re @@ -14,6 +16,9 @@ from datetime import datetime, timedelta from passlib.hash import sha256_crypt from hashlib import md5 + +metadata = MetaData() + followers = db.Table('followers', db.Column('follower_id', db.Integer, db.ForeignKey('user.id')), db.Column('followed_id', db.Integer, db.ForeignKey('user.id')) @@ -24,6 +29,20 @@ acquaintances = db.Table('acquaintances', db.Column('acquaintance_id', db.Integer, db.ForeignKey('user.id')) ) +# This is the association table for the many-to-many relationship between +# groups and members - this is, the memberships. +user_group_table = db.Table('user_group', + db.Column('user_id', db.Integer, + db.ForeignKey('user.id', + onupdate="CASCADE", + ondelete="CASCADE"), + primary_key=True), + db.Column('group_id', db.Integer, + db.ForeignKey('group.id', + onupdate="CASCADE", + ondelete="CASCADE"), + primary_key=True)) + class User(UserMixin, db.Model): """ @@ -42,8 +61,10 @@ class User(UserMixin, db.Model): example_data_item_id = db.Column(db.ForeignKey(ExampleDataItem.id)) about_me = db.Column(db.String(140)) role = db.Column(db.String(64)) + groups = db.relationship('Group', secondary=user_group_table, back_populates='users') password = db.Column(db.String(255), nullable=True) registered_on = db.Column(db.DateTime, nullable=False, default=datetime.utcnow()) + external_user = db.Column(db.Boolean, default=False) last_seen = db.Column(db.DateTime, default=datetime.utcnow()) jwt_exp_delta_seconds = db.Column(db.Integer, nullable=True) acquainted = db.relationship('User', @@ -62,11 +83,21 @@ class User(UserMixin, db.Model): def __init__(self, **kwargs): super(User, self).__init__(**kwargs) password = kwargs.get("password", None) + external_user = kwargs.get("external_user", None) + groups = kwargs.get("groups", None) if password is not None: self.password = sha256_crypt.encrypt(password) - # do custom initialization here + if external_user is not None: + self.external_user = external_user + if groups is not None: + if isinstance(groups, list): + for g in groups: + self.groups.append(g) + elif isinstance(groups, str): + self.groups.append(groups) @staticmethod + @login_manager.user_loader def get_by_identifier(identifier): """ Find user by identifier, which might be the nickname or the e-mail address. @@ -152,6 +183,15 @@ class User(UserMixin, db.Model): # TODO: implement correctly return False + @property + def is_read_only(self): + """ + Returns true if user is active. + :return: + """ + # TODO: implement correctly + return True + @staticmethod def decode_auth_token(auth_token): """ @@ -307,7 +347,7 @@ class User(UserMixin, db.Model): followers.c.follower_id == self.id).order_by(Post.timestamp.desc()) def to_dict(self): - return dict(id=self.id, email=self.email) + return dict(id=self.id, email=self.email, groups=self.groups) def __repr__(self): return '' % self.nickname @@ -342,3 +382,24 @@ class BlacklistToken(db.Model): return True else: return False + + +class Group(db.Model): + def __init__(self, **kwargs): + super(Group, self).__init__(**kwargs) + + id = db.Column(db.Integer, autoincrement=True, primary_key=True) + name = db.Column(db.Unicode(64), unique=True, nullable=False) + users = db.relationship('User', secondary=user_group_table, back_populates='groups') + + @staticmethod + def get_by_name(name): + """ + Find group by name + :param name: + :return: + """ + return Group.query.filter(Group.name == name).first() + + def __str__(self): + return self.name