From 7a90e844b7807fe5e6ceb3eb5141874c632c1c52 Mon Sep 17 00:00:00 2001 From: Hugo Sales Date: Fri, 23 Apr 2021 12:54:25 +0000 Subject: [PATCH] [SECURITY][DB] Make user register 'atomic', by using a single transaction for inserting all objects, to avoid partial inserts --- src/Controller/Security.php | 19 ++++++++----------- src/Core/DB/DB.php | 18 ++++++++++++++++++ 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/src/Controller/Security.php b/src/Controller/Security.php index eb330d53c3..06b5f1456a 100644 --- a/src/Controller/Security.php +++ b/src/Controller/Security.php @@ -93,17 +93,19 @@ class Security extends Controller try { $actor = GSActor::create(['nickname' => $data['nickname']]); - DB::persist($actor); - DB::flush(); - $id = $actor->getId(); - $user = LocalUser::create([ - 'id' => $id, + $user = LocalUser::create([ 'nickname' => $data['nickname'], 'outgoing_email' => $data['email'], 'incoming_email' => $data['email'], 'password' => LocalUser::hashPassword($data['password']), ]); - DB::persist($user); + DB::persistWithSameId( + $actor, + $user, + // Self follow + fn (int $id) => DB::persist(Follow::create(['follower' => $id, 'followed' => $id])) + ); + DB::flush(); } catch (UniqueConstraintViolationException $e) { throw new NicknameTakenException; } @@ -123,11 +125,6 @@ class Security extends Controller $user->setIsEmailVerified(true); } - // Self follow - $follow = Follow::create(['follower' => $id, 'followed' => $id]); - DB::persist($follow); - DB::flush(); - return $guard_handler->authenticateUserAndHandleSuccess( $user, $request, diff --git a/src/Core/DB/DB.php b/src/Core/DB/DB.php index c21f6fe495..b2586a1225 100644 --- a/src/Core/DB/DB.php +++ b/src/Core/DB/DB.php @@ -173,6 +173,24 @@ abstract class DB return $repo->count($table, $criteria); } + /** + * Insert all given objects with the generated ID of the first one + */ + public static function persistWithSameId(object $owner, object | array $others, ?callable $extra = null) + { + $conn = self::getConnection(); + $metadata = self::getClassMetadata(get_class($owner)); + $seqName = $metadata->getSequenceName($conn->getDatabasePlatform()); + self::persist($owner); + $id = $conn->lastInsertId($seqName); + F\map(is_array($others) ? $others : [$others], function ($o) use ($id) { $o->setId($id); self::persist($o); }); + if (!is_null($extra)) { + $extra($id); + } + self::flush(); + return $id; + } + /** * Intercept static function calls to allow refering to entities * without writing the namespace (which is deduced from the call